In [1]:
%%capture
!pip install tensorflow-addons[tensorflow]

In [None]:
from google.colab import drive
drive.mount("/content/drive")
%cd "/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust"

Mounted at /content/drive
/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust


In [2]:
import os, glob 
import numpy as np
import pandas as pd

In [3]:
import tensorflow as tf

from tensorflow.keras import Model

import tensorflow.image as transforms

In [4]:
def standardize_gray_image(image, label):
    # see https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization
    mean_img, std_img = tf.math.reduce_mean(image), tf.math.reduce_std(image)
    num_els = 1.0 * image.shape[0] * image.shape[1]
    adjusted_stddev = tf.maximum(std_img, 1.0/tf.math.sqrt(num_els))
    image = (image - mean_img)/adjusted_stddev
    return image, label

def permute_gray_image(image, label): 
    rimg = tf.reshape(image, [image.shape[0] * image.shape[1]])
    permut = list(np.random.permutation(len(rimg)))
    rimg = tf.gather(rimg, indices=permut)
    image = tf.reshape(rimg, image.shape)
    return image, label

def test_permute(image):
    import matplotlib.pyplot as plt
    # original 
    plt.subplot(131)
    plt.imshow(image)
    plt.title('original')
    # permute pixels
    rimg = tf.reshape(image, -1)
    permut = list(np.random.permutation(len(rimg)))
    rimg = tf.gather(rimg, indices=permut)
    plt.subplot(132)
    plt.imshow(tf.reshape(rimg, (28, 28)))
    plt.title('permuted')
    # recover for checking  
    rimg = tf.gather(rimg, indices=tf.math.invert_permutation(permut))
    plt.subplot(133)
    plt.imshow(tf.reshape(rimg, (28, 28)))
    plt.title('recovered')


In [5]:
def parse_task(full_task, sep='-'):
    # e.g 'MNIST'   -> {'action': 'none', 'task':'MNIST'}
    # e.g 'p-MNIST' -> {'action': 'p', 'task':'MNIST'}
    if sep not in full_task: 
        action, task = 'none', full_task
    else: 
        action, task = full_task.split(sep)
    return task, action
    

def create_data_iter(task, action='none', batch_size=100, return_dict=False, 
                     shuff_buffsz=2000, map_npar=tf.data.AUTOTUNE,
                     ds_prefetch=True, ds_cache=True, return_iters=True):
    task = task.upper() 

    if task in ['MNIST', 'FMNIST']:
        # Load data 
        tf_ds = tf.keras.datasets.mnist if task == 'MNIST' else tf.keras.datasets.fashion_mnist
        x, y, ds = dict(), dict(), dict()
        (x['train'], y['train']), (x['test'], y['test']) = tf_ds.load_data()
        for k in ['train', 'test']:
            x[k] = (x[k]/255.0).astype('float32')
        
        # Create tensor dataset 
        ds = {k: tf.data.Dataset.from_tensor_slices((x[k],y[k])) for k in ['train', 'test']}

        # Perform permutation 
        if action.lower() in ['p', 'perm', 'permuted']: 
            ds = {k: v.map(permute_gray_image, num_parallel_calls=map_npar) for k,v in ds.items()}

        # Perform image standardization 
        ds = {k: v.map(standardize_gray_image, num_parallel_calls=map_npar) for k,v in ds.items()}

        # Prepare for training + batching 
        ds['train'] = ds['train'].shuffle(buffer_size=shuff_buffsz)
        ds = {k: v.batch(batch_size) for k,v in ds.items()}

        # Optional to speedup
        if ds_prefetch: ds = {k: v.prefetch(tf.data.AUTOTUNE) for k,v in ds.items()}
        if ds_cache: ds = {k: v.cache() for k,v in ds.items()}

        # # Create iterators
        if return_iters: ds = {k: iter(v) for k,v in ds.items()}

        if return_dict: return ds
        return ds['train'], ds['test']       
         
    else:
        raise('"%s" task not implemented' %(task))
    

In [11]:
batch_size = 100
task_sequences = ['MNIST', 'p-MNIST', 'p-MNIST', 'p-MNIST']

In [12]:
ds = create_data_iter('MNIST', batch_size=batch_size, return_dict=True, return_iters=False)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
ds = []

for full_task in task_sequences:
    task, action = parse_task(full_task)
    ds.append(create_data_iter(task, action, batch_size=batch_size, return_dict=True))


In [92]:
import tensorflow as tf

from tensorflow.keras import layers, initializers
from tensorflow.keras import Model, Sequential

@tf.custom_gradient
def SignActivation(x):
    y = tf.sign(x)
    def grad(grad_output):
        # use hardtanh gradient (see paper): anything with abs() > 1 -> 0
        return tf.where(tf.abs(x) > 1.0, 0.0, grad_output)
    return y, grad

def Binarize(tensor):
    return tf.sign(tensor)

class BinarizeLinear(layers.Layer):
    '''
    units:          number of units for layer 
    init_type:      'gauss' or 'uniform' for weight initialization in Dense layer
    init_width:     used for initialization widths or stddev for weight initialization
    dropout_rate:   if not None, will use to construct Dropout layer
    act_fun:        if not None, activation function; currently only 'sign' (SignActivation)
    norm_type:      if not None, will use for normalization layer; currently only 'bn' (batchnorm)
    norm_args:      (dict) normalization layer arguments 
    '''
    def __init__(self, units, init_type = 'gauss', init_width = 0.01, 
                 dropout_rate = None, norm_type = 'bn', act_fun = 'sign', 
                 norm_args = dict(momentum=0.1,epsilon=1e-5), 
                 name = 'bfc'):
        
        super(BinarizeLinear, self).__init__(name=name)
        self.units = units
        self.init_type = init_type
        self.init_width = init_width
        self.dropout_rate = dropout_rate
        self.act_fun = act_fun
        self.norm_type = norm_type
        self.norm_args = norm_args

    def get_dense_initializer(self): 
        kernel_init = 'glorot_uniform'
        init_width = self.init_width
        if self.init_type == 'gauss': 
            kernel_init = initializers.RandomNormal(mean=0.0, stddev=init_width)
        if self.init_type == 'uniform': 
            kernel_init = initializers.RandomUniform(minval=-init_width/2, maxval=init_width/2)
        return kernel_init
    
    def build(self, input_shape): 
        self.inp_dim = input_shape 
        
        # Dense linear layer 
        self.fc = layers.Dense(self.units, use_bias=False, activation=None, 
                                kernel_initializer=self.get_dense_initializer())
        self.fc.build(input_shape)

        # Create dropout layer
        if self.dropout_rate:
            self.dropout = layers.Dropout(rate=self.dropout_rate)
        
        # Create normalization layer 
        if self.norm_type:
            if self.norm_type == 'bn': 
                self.norm = layers.BatchNormalization(**self.norm_args)
            else:
                raise('Only "bn" (batchnorm) or None is allowed for normalization at this point')
        
        # Acitvation 
        if self.act_fun:
            if self.act_fun == 'sign':
                self.act = SignActivation
            else:
                raise('Only "sign" (SignActivation) or None is allowed for activation at this point')
        
    def call(self, input):
        org_fc_kernel = tf.identity(self.fc.kernel)
        self.fc.kernel.assign(Binarize(self.fc.kernel))        
        out = self.fc(input)  
        self.fc.kernel.assign(org_fc_kernel)
                
        if self.dropout_rate: 
            out = self.dropout(out)
        
        if self.norm_type:
            out = self.norm(out)
        
        if self.act_fun:
            out = self.act(out)
        
        return out

class BNN(Model):
    '''
    layers_dims:    [(input_height, input_width), hidden_1, hidden_2, ..., output]
    '''
    def __init__(self, layers_dims, **kwargs):
        super(BNN, self).__init__()
        self.layers_dims = layers_dims
        self.num_hidden = len(layers_dims) - 2
        
        self.hidden_args = dict(**kwargs)

        self.output_args = dict(**kwargs)
        self.output_args['act_fun'] = None # no activation at output 
        
        # define layers 
        self.flatten = layers.Flatten(input_shape=layers_dims[0])
        self.bfcs = Sequential([
            BinarizeLinear(layers_dims[i], 
                           **self.hidden_args, 
                           name = 'bfc-%02d' %(i))
            for i in range(1, self.num_hidden+1)
        ])
        self.out = BinarizeLinear(layers_dims[-1], **self.output_args, name='output')

    def call(self, x):
        x = self.flatten(x)
        x = self.bfcs(x)
        x = self.out(x)
        return x

In [88]:
layer_dims = [(28,28), 512, 512, 10]
model = BNN(layer_dims)

In [89]:
# ds = create_data_iter('MNIST', batch_size=batch_size, return_dict=True, return_iters=False)
Xs, Ys = next(iter(ds['train']))

In [90]:
Yhat = model(Xs)

In [29]:
Yhat.shape

TensorShape([100, 10])

In [93]:
for p in model.variables:
    print('{} \t shape={} dim={}'.format(p.name,p.shape,len(p.shape)))

bfc-01/kernel:0 	 shape=(784, 512) dim=2
bfc-01/batch_normalization/gamma:0 	 shape=(512,) dim=1
bfc-01/batch_normalization/beta:0 	 shape=(512,) dim=1
bfc-01/batch_normalization/moving_mean:0 	 shape=(512,) dim=1
bfc-01/batch_normalization/moving_variance:0 	 shape=(512,) dim=1
bfc-02/kernel:0 	 shape=(512, 512) dim=2
bfc-02/batch_normalization/gamma:0 	 shape=(512,) dim=1
bfc-02/batch_normalization/beta:0 	 shape=(512,) dim=1
bfc-02/batch_normalization/moving_mean:0 	 shape=(512,) dim=1
bfc-02/batch_normalization/moving_variance:0 	 shape=(512,) dim=1
bnn_8/output/kernel:0 	 shape=(512, 10) dim=2
bnn_8/output/batch_normalization_5/gamma:0 	 shape=(10,) dim=1
bnn_8/output/batch_normalization_5/beta:0 	 shape=(10,) dim=1
bnn_8/output/batch_normalization_5/moving_mean:0 	 shape=(10,) dim=1
bnn_8/output/batch_normalization_5/moving_variance:0 	 shape=(10,) dim=1


In [94]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops, control_flow_ops, math_ops, state_ops
from tensorflow.python.keras import backend_config

import tensorflow_addons as tfa
from tensorflow_addons.optimizers import DecoupledWeightDecayExtension


class Adam_meta(keras.optimizers.Optimizer):
    '''
    Adam optimizer with `meta` parameter

    PARAMETERS:
    - meta:   meta-plasticity parameter, for now only allows scalar values (in paper allows layer-wise)

    NOTE:
        the rest parameters are similar to original Adam

    TODO:
    - wrap around with [DecoupledWeightDecayExtension] -> Adam_meta_W
    - consider adding `f_meta` as an option
    - consider applying heterogeneity in `meta` like paper

    SOURCE:
    - [Adam-keras](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/optimizer_v2/adam.py)
    - [AdamW-keras](https://github.com/OverLordGoldDragon/keras-adamw/blob/master/keras_adamw/optimizers_v2.py)
    - [Adam_meta-Torch](https://github.com/Laborieux-Axel/SynapticMetaplasticityBNN/blob/master/Continual_Learning_Fig-2abcdefgh-3abcd-5cde/models_utils.py)
    - [DecoupledWeightDecayExtension]: https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/DecoupledWeightDecayExtension

    '''

    def __init__(self,
                 meta=0,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-8,
                 amsgrad=False, 
                 decay=0,
                 name="Adam-meta",
                 **kwargs):
        # Check for conditions
        if min(meta, learning_rate, epsilon) < 0.0:
            raise ValueError('Invalid "meta" or "learning_rate" or "epsilon". Needs all to be non-negative')
        if min(beta_1, beta_2) < 0.0 or max(beta_1, beta_2) > 1.0:
            raise ValueError('Invalid "beta_1" or "beta_2". Needs both to be within [0,1]')

        # Initialization 
        super(Adam_meta, self).__init__(name, **kwargs)
        
        # Add hyperparameters
        self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
        self._set_hyper('meta', meta)
        self._set_hyper('beta_1', beta_1)
        self._set_hyper('beta_2', beta_2)
        self.decay = decay
        self.epsilon = epsilon or backend_config.epsilon()
        self.amsgrad = amsgrad

    def _create_slots(self, var_list):
        '''Create slots for the first and second moments.
        Exactly similar to [AdamW-keras]
        '''
        for var in var_list:
            self.add_slot(var, 'm') # 1st moment
        for var in var_list:
            self.add_slot(var, 'v') # 2nd moment
        if self.amsgrad:
            for var in var_list:
                self.add_slot(var, 'vhat') # 2nd moment in case AMSGrad
        self._updates_per_iter = len(var_list)


    # @tf.function
    def _resource_apply_dense(self, grad, var):
        '''Update the slots and perform one optimization step for one model variable for metaplasticity
        This is mirroring [AdamW-keras] and [Adam_meta-Torch].
        '''
        var_device, var_dtype = var.device, var.dtype.base_dtype

        # Get slots for 1st and 2nd moments
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')

        # Get hyperparameters
        meta_t = array_ops.identity(self._get_hyper('meta', var_dtype))
        lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        decay_t = ops.convert_to_tensor(self.decay, var_dtype)

        # Compute parameters based on current local step
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_power = math_ops.pow(beta_1_t, local_step) # B1^t
        beta_2_power = math_ops.pow(beta_2_t, local_step) # B2^t

        # Learning rate bias correction
        # eta_t <- eta * sqrt(1-B2^t) / (1-B1^t)
        lr_t = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        # Essential ADAM equations
        # m: 1st moment
        # v: 2nd moment
        # g: grad
        # m <- B1 * m + (1 - B1) * g
        # v <- B2 * v + (1 - B2) * g^2
        if self.decay != 0:
            grad = math_ops.add(grad, math_ops.multiply(decay_t,var))
        m_t = state_ops.assign(m, beta_1_t * m + (1.0 - beta_1_t) * grad, use_locking=self._use_locking)
        v_t = state_ops.assign(v, beta_2_t * v + (1.0 - beta_2_t) * math_ops.square(grad), use_locking=self._use_locking)

        # Apply AMSGrad if turned on
        # usually var_delta = dX <- m / (sqrt(v or v_hat) + eps)
        # var_delta_denom <- sqrt(v or v_hat) + eps
        # but metaplast will change a bit so only calc denom now
        if self.amsgrad: # v_hat <- max(v_hat, v_t)
            vhat = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(vhat, math_ops.maximum(vhat, v_t), use_locking=self._use_locking)
            var_delta_denom = math_ops.sqrt(vhat_t) + epsilon_t
        else:
            var_delta_denom = math_ops.sqrt(v_t) + epsilon_t

        # Metaplasticity
        if len(var.shape) == 1:  # True if bias or BN params, false if weight. TODO: Need to double check
            # X <- X - eta * sqrt(1-B2^t) / (1-B1^t) * m / (sqrt(v or v_hat) + eps)
            # X <- X - eta_t * dX
            # dX <- m / (var_delta_denom = sqrt(v or v_hat) + eps)
            var_t = math_ops.sub(var, lr_t * m_t / var_delta_denom)
        else:
            # the variables will be similar to [Adam_meta-Torch] code and try to mirror paper
            # binary_weight_before_update: Wb <- sign(Wh)
            # condition_consolidation: use_meta <- Uw * Wb > 0.0
            # Uw <- dX
            # Wh <- var
            Wb = math_ops.sign(var)
            use_meta = math_ops.multiply(Wb, m_t) > 0.0 # sign(m_t) = sign(dX)

            # f_meta = 1 - tanh(m * Wh)^2
            f_meta = array_ops.ones_like(var) - math_ops.square(math_ops.tanh(meta_t * var))

            # only use meta-applied m_t when use_meta = True
            # i.e. only use f_meta when Wb * m_t/denom > 0
            decayed_m_t = math_ops.multiply(f_meta, m_t)
            alt_m_t = array_ops.where(use_meta, decayed_m_t, m_t)

            # X <- X - eta_t * dX
            # dX <- (f_meta(X=Wh) if Wb*Uw >0 else 1.0) * m / (var_delta_denom = sqrt(v or v_hat) + eps)
            var_t = math_ops.sub(var, lr_t * alt_m_t / var_delta_denom)

        # Return updates
        var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)
        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return control_flow_ops.group(*updates)

    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        config = super().get_config()
        config.update({
            'learning_rate': self._serialize_hyperparameter('learning_rate'),
            'meta': self._serialize_hyperparameter('meta'),
            'beta_1': self._serialize_hyperparameter('beta_1'),
            'beta_2': self._serialize_hyperparameter('beta_2'),
            'epsilon': self.epsilon,
            'amsgrad': self.amsgrad
        })
        return config

class Adam_meta_W(DecoupledWeightDecayExtension, Adam_meta):
    '''
    NOTE: Untested
    '''
    def __init__(self,
                 weight_decay,
                 meta           = 0,
                 learning_rate  = 0.001,
                 beta_1         = 0.9,
                 beta_2         = 0.999,
                 epsilon        = 1e-8,
                 amsgrad        = False,
                 name           = "Adam_meta_W",
                 **kwargs):
        super().__init__(
            weight_decay,
            meta            = meta,
            learning_rate   = learning_rate,
            beta_1          = beta_1,
            beta_2          = beta_2,
            epsilon         = epsilon,
            amsgrad         = amsgrad,
            name            = name,
            **kwargs)

In [97]:
batch_size = 100
meta = 1.35
learning_rate = 0.005
weight_decay = 1e-7 
gamma 

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

optimizer = Adam_meta(meta = meta, learning_rate = learning_rate, decay = weight_decay)
# optimizer = tfa.optimizers.extend_with_decoupled_weight_decay(Adam_meta)(weight_decay = weight_decay, meta = meta, learning_rate = learning_rate)
# optimizer = Adam_meta_W(weight_decay = weight_decay, meta = meta, learning_rate = learning_rate)
# optimizer = Adam_meta(meta = meta, learning_rate = learning_rate)
# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# optimizerW = Adam_meta_W(weight_decay = weight_decay, meta = meta, learning_rate = learning_rate)


In [98]:
layer_dims = [(28,28), 4096, 4096, 10]
layer_dims = [(28,28), 512, 512, 10]
model = BNN(layer_dims, init_type = 'uniform', init_width = 0.1, dropout_rate = None) # dropout_rate=None, norm_type=None)

In [110]:
ds = create_data_iter('MNIST', batch_size=100, return_dict=True, return_iters=False)

In [111]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [142]:
def reset_metrics():
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_accuracy.update_state(labels, predictions)
    

@tf.function
def test_step(images, labels):
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)



In [143]:
from tqdm.notebook import tqdm

In [None]:
%%time 
num_epochs = 100
for epoch in tqdm(range(num_epochs)):
    reset_metrics()

    for images, labels in iter(ds['train']):
        train_step(images, labels)

    for images, labels in iter(ds['test']):
        test_step(images, labels)

    print('\t Epoch: %d | Loss: %.4f | Train Accuracy: %.2f | Test Acc: %.2f' \
        %(epoch, train_loss.result(), train_accuracy.result() * 100, test_accuracy.result() * 100))

In [147]:
batch_size = 200
task_sequences = ['MNIST', 'p-MNIST', 'p-MNIST', 'p-MNIST']

In [148]:
all_ds = []
for id, full_task in enumerate(task_sequences):
    task, action = parse_task(full_task)
    ds = create_data_iter(task, action, batch_size=batch_size, return_dict=True, return_iters=False)
    ds['task_id' ] = 'task-%02d' %(id)
    all_ds.append(ds)

In [159]:
batch_size = 100
meta = 1.35
learning_rate = 0.005
weight_decay = 1e-7 

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = Adam_meta(meta = meta, learning_rate = learning_rate, decay = weight_decay)

In [178]:
# layer_dims = [(28,28), 512, 512, 10]
layer_dims = [(28,28), 4000, 4000, 10]
model = BNN(layer_dims, init_type = 'uniform', init_width = 0.1, dropout_rate = None) # dropout_rate=None, norm_type=None)

In [179]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [180]:
def reset_metrics():
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_accuracy.update_state(labels, predictions)
    

@tf.function
def test_step(images, labels):
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)



In [None]:
all_metrics['train']

In [176]:
import pandas as pd
pd.DataFrame(all_metrics['test'])

Unnamed: 0,loss,acc,task_id,epoch
0,0.149725,95.389998,task-00,1
1,4.050347,10.960000,task-01,1
2,3.847055,12.930000,task-02,1
3,3.979600,9.580000,task-03,1
4,0.114581,96.439999,task-00,2
...,...,...,...,...
251,7.287422,8.660000,task-03,3
252,2.144198,50.919998,task-00,4
253,8.080349,9.300000,task-01,4
254,7.491025,11.360000,task-02,4


In [181]:
%%time 
num_epochs_per_task = 5
test_accuracy_all_task = []
all_metrics = dict(
    train = [], 
    test = [],
)

def get_current_metrics(loss, acc, task_id, epoch):
    return dict(
        loss = float(loss.result()),
        acc = float(acc.result())*100.0,
        task_id = task_id,
        epoch = epoch
    )

for curr_ds in tqdm(all_ds):
    curr_train_ds = curr_ds['train']
    curr_task_id = curr_ds['task_id']
    for epoch in tqdm(range(num_epochs_per_task)):
        # training
        reset_metrics()
        for images, labels in iter(curr_train_ds):
            train_step(images, labels)
        curr_metrics = get_current_metrics(train_loss, train_accuracy, curr_task_id, epoch+1)
        print('TASK: %s - Epoch %02d -> Train loss: %.4f \t acc: %.2f%%' 
              %(curr_task_id, epoch, curr_metrics['loss'], curr_metrics['acc']))
        all_metrics['train'].append(curr_metrics)

        # testing 
        for ds in all_ds:
            test_taskid = ds['task_id']
            reset_metrics()
            for images, labels in iter(ds['test']):
                test_step(images, labels)
            curr_metrics = get_current_metrics(test_loss, test_accuracy, test_taskid, epoch+1)
            all_metrics['test'].append(curr_metrics)

            print('\t+ Test task %s: %.2f%%' %(test_taskid, curr_metrics['acc']))

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

TASK: task-00 - Epoch 00 -> Train loss: 0.2419 	 acc: 93.46%
	+ Test task task-00: 96.02%
	+ Test task task-01: 9.69%
	+ Test task task-02: 11.16%
	+ Test task task-03: 10.22%
TASK: task-00 - Epoch 01 -> Train loss: 0.0977 	 acc: 97.17%
	+ Test task task-00: 96.37%
	+ Test task task-01: 11.64%
	+ Test task task-02: 12.52%
	+ Test task task-03: 10.75%
TASK: task-00 - Epoch 02 -> Train loss: 0.0675 	 acc: 98.09%
	+ Test task task-00: 95.95%
	+ Test task task-01: 10.75%
	+ Test task task-02: 11.66%
	+ Test task task-03: 12.14%
TASK: task-00 - Epoch 03 -> Train loss: 0.0531 	 acc: 98.49%
	+ Test task task-00: 96.79%
	+ Test task task-01: 5.17%
	+ Test task task-02: 13.28%
	+ Test task task-03: 11.18%
TASK: task-00 - Epoch 04 -> Train loss: 0.0446 	 acc: 98.74%
	+ Test task task-00: 95.89%
	+ Test task task-01: 8.34%
	+ Test task task-02: 10.46%
	+ Test task task-03: 8.95%


  0%|          | 0/5 [00:00<?, ?it/s]

TASK: task-01 - Epoch 00 -> Train loss: 0.2083 	 acc: 94.06%
	+ Test task task-00: 47.09%
	+ Test task task-01: 7.76%
	+ Test task task-02: 9.16%
	+ Test task task-03: 8.03%
TASK: task-01 - Epoch 01 -> Train loss: 0.0723 	 acc: 97.83%
	+ Test task task-00: 42.51%
	+ Test task task-01: 8.39%
	+ Test task task-02: 10.19%
	+ Test task task-03: 10.42%
TASK: task-01 - Epoch 02 -> Train loss: 0.0441 	 acc: 98.77%
	+ Test task task-00: 35.55%
	+ Test task task-01: 9.03%
	+ Test task task-02: 10.95%
	+ Test task task-03: 10.09%
TASK: task-01 - Epoch 03 -> Train loss: 0.0339 	 acc: 99.01%
	+ Test task task-00: 31.89%
	+ Test task task-01: 9.21%
	+ Test task task-02: 10.36%
	+ Test task task-03: 10.46%
TASK: task-01 - Epoch 04 -> Train loss: 0.0321 	 acc: 99.06%
	+ Test task task-00: 31.41%
	+ Test task task-01: 8.62%
	+ Test task task-02: 7.15%
	+ Test task task-03: 12.14%


  0%|          | 0/5 [00:00<?, ?it/s]

TASK: task-02 - Epoch 00 -> Train loss: 0.2039 	 acc: 94.30%
	+ Test task task-00: 7.93%
	+ Test task task-01: 11.99%
	+ Test task task-02: 9.69%
	+ Test task task-03: 9.80%
TASK: task-02 - Epoch 01 -> Train loss: 0.0618 	 acc: 98.13%
	+ Test task task-00: 9.79%
	+ Test task task-01: 11.39%
	+ Test task task-02: 12.03%
	+ Test task task-03: 11.90%
TASK: task-02 - Epoch 02 -> Train loss: 0.0346 	 acc: 99.02%
	+ Test task task-00: 8.88%
	+ Test task task-01: 11.57%
	+ Test task task-02: 8.95%
	+ Test task task-03: 9.42%
TASK: task-02 - Epoch 03 -> Train loss: 0.0280 	 acc: 99.21%
	+ Test task task-00: 13.30%
	+ Test task task-01: 11.14%
	+ Test task task-02: 11.00%
	+ Test task task-03: 12.58%
TASK: task-02 - Epoch 04 -> Train loss: 0.0262 	 acc: 99.18%
	+ Test task task-00: 11.24%
	+ Test task task-01: 11.42%
	+ Test task task-02: 9.64%
	+ Test task task-03: 12.51%


  0%|          | 0/5 [00:00<?, ?it/s]

TASK: task-03 - Epoch 00 -> Train loss: 0.2005 	 acc: 94.52%
	+ Test task task-00: 9.79%
	+ Test task task-01: 10.50%
	+ Test task task-02: 10.50%
	+ Test task task-03: 9.77%
TASK: task-03 - Epoch 01 -> Train loss: 0.0549 	 acc: 98.38%
	+ Test task task-00: 8.53%
	+ Test task task-01: 9.84%
	+ Test task task-02: 10.33%
	+ Test task task-03: 11.41%
TASK: task-03 - Epoch 02 -> Train loss: 0.0275 	 acc: 99.29%
	+ Test task task-00: 9.29%
	+ Test task task-01: 9.23%
	+ Test task task-02: 11.37%
	+ Test task task-03: 10.99%
TASK: task-03 - Epoch 03 -> Train loss: 0.0222 	 acc: 99.32%
	+ Test task task-00: 7.97%
	+ Test task task-01: 9.24%
	+ Test task task-02: 11.37%
	+ Test task task-03: 12.34%


KeyboardInterrupt: ignored

# Testing OG torch version

In [None]:
from google.colab import drive
drive.mount("/content/drive")
%cd "/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust/ext/SynapticMetaplasticityBNN/Continual_Learning_Fig-2abcdefgh-3abcd-5cde/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust/ext/SynapticMetaplasticityBNN/Continual_Learning_Fig-2abcdefgh-3abcd-5cde


In [None]:
!ls ..

CIFAR-features				     README.md
Continual_Learning_Fig-2abcdefgh-3abcd-5cde  requirements.txt
LICENSE					     Stream_Learning_CIFAR10_Fig-4b
MNIST-USPS				     Stream_Learning_FMNIST_Fig-4a
Quadratic_Binary_Task_Fig-5ab


In [None]:
!python main.py --net 'bnn' --hidden-layers 512 512 --lr 0.005 --decay 1e-7 --meta 1.35 --epochs-per-task 10 --task-sequence 'MNIST' 'pMNIST' 'pMNIST' 'pMNIST'

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
BNN(
  (layers): ModuleDict(
    (fc1): BinarizeLinear(in_features=784, out_features=512, bias=False)
    (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc2): BinarizeLinear(in_features=512, out_features=512, bias=False)
    (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc3): BinarizeLinear(in_features=512, out_features=10, bias=False)
    (bn3): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:1025.)
  grad.add_(group['weight_decay'], p.data)
Test accuracy: 57487/60000 (95.81%)
Test accuracy: 9518/10000 (95.18%)
Test accuracy: 58249/60000 (97.08%)
Test accuracy: 9591/10000 (95.91%)
Test accura