In [None]:
import argparse
import json
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

In [None]:
# We submitted these files along with this notebook

from Adam import Adam_Metaplastic
from data_utils import *
from BNN_Conv import *

### Define Task Sequence

In [None]:
# Ordered list of tasks to train on sequentially: five permuted Fashion-MNIST tasks.
task_sequence = ['pFMNIST'] * 5

### Load in Training Tasks

In [None]:
# Build parallel lists with one entry per task in `task_sequence`:
# train loader, test loader, raw training dataset and a display name.
# Code borrowed from Laborieux, et al. and slightly modified for our use case.

train_loader_list = []
test_loader_list = []
dset_train_list = []
task_names = []

for idx, task in enumerate(task_sequence):
    if task == 'MNIST':
        train_loader_list.append(mnist_train_loader)
        test_loader_list.append(mnist_test_loader)
        dset_train_list.append(mnist_dset_train)
        task_names.append(task)
    elif task == 'USPS':
        train_loader_list.append(usps_train_loader)
        test_loader_list.append(usps_test_loader)
        dset_train_list.append(usps_dset_train)
        task_names.append(task)
    elif task == 'FMNIST':
        train_loader_list.append(fashion_mnist_train_loader)
        test_loader_list.append(fashion_mnist_test_loader)
        dset_train_list.append(fmnist_dset_train)
        task_names.append(task)
    elif task in ('pMNIST', 'pFMNIST', 'pUSPS'):
        # Strip the leading 'p' to pass the base dataset name ('MNIST', 'FMNIST', 'USPS').
        # NOTE(review): presumably each call draws a fresh pixel permutation — confirm in data_utils.
        train_loader, test_loader, dset_train = create_permuted_loaders(task[1:])
        train_loader_list.append(train_loader)
        test_loader_list.append(test_loader)
        dset_train_list.append(dset_train)
        # Suffix the 1-based position so repeated permuted tasks get distinct names.
        task_names.append(task + str(idx + 1))
    elif task == 'animals':
        animals_train_loader, animals_test_loader, animals_dset_train = process_cifar10(task)
        train_loader_list.append(animals_train_loader)
        test_loader_list.append(animals_test_loader)
        dset_train_list.append(animals_dset_train)
        task_names.append('animals')
    elif task == 'vehicles':
        vehicles_train_loader, vehicles_test_loader, vehicles_dset_train = process_cifar10(task)
        train_loader_list.append(vehicles_train_loader)
        test_loader_list.append(vehicles_test_loader)
        dset_train_list.append(vehicles_dset_train)
        task_names.append('vehicles')
    elif 'cifar100' in task:
        # NOTE: this REPLACES the lists built so far (it does not append), so a
        # cifar100 entry only makes sense as the sole item of task_sequence.
        n_subset = int(task.split('-')[1])  # task = "cifar100-20" -> n_subset = 20
        train_loader_list, test_loader_list, dset_train_list = process_cifar100(n_subset)
        task_names = ['cifar100-' + str(i + 1) for i in range(n_subset)]
    else:
        # Fail loudly instead of silently dropping a misspelled task.
        raise ValueError(f"Unknown task {task!r} in task_sequence")

### Hyperparameters

In [None]:
# Experiment hyperparameters — edit here to change the run.
lr = 5e-3            # learning rate for the first task (later tasks: see `lrs`)
epochs = 10          # training epochs per task
save_result = True
meta = 2             # `meta` argument passed to Adam_Metaplastic
archi = [784, 10]    # layer sizes; not used by ConvNet in the cells shown
init_width = 0.1     # passed to ConvNet(width=...)
decay = 0            # weight_decay passed to Adam_Metaplastic
gamma = 1            # per-task factor: task i uses lr * gamma**(-i)
norm = 'batch'       # recorded in the results dict

## Define Model

In [None]:
# Instantiate the ConvNet from BNN_Conv, then move its parameters to `device`
# (nn.Module.to moves the module in place and returns the same module).
model = ConvNet(width=init_width)
model = model.to(device)

In [None]:
# Results record: filled in during training (one entry per task/epoch)
# and written out as JSON at the end.
data = {'net': 'BNN'}
arch = 'STANDARD'
data['arch'] = arch
data['norm'] = norm
data.update({key: [] for key in
             ('lr', 'meta', 'task_order', 'tsk', 'epoch', 'acc_tr', 'loss_tr')})

# One test-accuracy list and one test-loss list per evaluated task.
data.update({f'{kind}_test_tsk_{task_no}': []
             for task_no in range(1, len(test_loader_list) + 1)
             for kind in ('acc', 'loss')})

# Run name built from the net, architecture and task names; all cifar100
# sub-tasks contribute a single component.
name = f"_{data['net']}_{data['arch']}_"
for task_label in task_names:
    if 'cifar100' in task_label and 'cifar100' in name:
        continue
    name += task_label + '-'

bn_states = []  # not read by any cell shown

# Learning rate for each task in sequence: task k uses lr * gamma**(-k).
lrs = [lr * gamma ** -k for k in range(len(train_loader_list))]

## Train and Test Functions


In [None]:
def train(model, train_loader, current_task_index, optimizer, device,
          prev_cons=None, prev_params=None, path_integ=None, criterion = torch.nn.CrossEntropyLoss()):
    """Train `model` for one pass over `train_loader`.

    Parameters that carry an `org` tensor (the BNN layers) follow the hidden-weight
    scheme from Laborieux et al.: before the optimizer step, `org` is copied into
    `.data` so the optimizer updates it; after the step, the updated `.data` is
    copied back into `org`.

    Args:
        model: network to train; switched to train mode.
        train_loader: iterable of (inputs, targets) batches.
        current_task_index: index of the task being trained (not used in the body).
        optimizer: optimizer over `model.parameters()`.
        device: device every batch is moved to.
        prev_cons, prev_params, path_integ: accepted but unused; kept so existing
            call sites keep working. NOTE(review): presumably leftovers from a
            regularization-based baseline.
        criterion: loss applied to (model output, targets).
    """
    model.train()

    for inputs, targets in train_loader:
        # Always move the batch: a non-CUDA device (e.g. 'mps') must also receive it.
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()

        # Restore the hidden weights into .data so the optimizer step acts on them
        # rather than on whatever the forward pass left in .data.
        for p in model.parameters():
            if hasattr(p, 'org'):
                p.data.copy_(p.org)

        optimizer.step()

        # Store the updated hidden weights back into `org`.
        for p in model.parameters():
            if hasattr(p, 'org'):
                p.org.copy_(p.data)

In [None]:
def test(model, test_loader, device, task_idx, criterion = torch.nn.CrossEntropyLoss(reduction='sum'), verbose = False):
    
    model.eval()
    test_loss = 0
    correct = 0
    
    for data, target in test_loader:
        if torch.cuda.is_available():
            data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += criterion(output, target).item() # mean batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    test_acc = round( 100. * float(correct) / len(test_loader.dataset)  , 2)
    
    if len(test_loader.dataset)==60000:
      print(f"Task_idx: {task_idx}")
      print('Train accuracy: {}/{} ({:.2f}%)'.format(
          correct, len(test_loader.dataset),
          test_acc))
    else:
      print(f"Task_idx: {task_idx}")
      print('Test accuracy: {}/{} ({:.2f}%)'.format(
          correct, len(test_loader.dataset),
          test_acc))
      
    return test_acc, test_loss

## Run Model

In [None]:
# Sequential training loop (modified from Laborieux, et al.).
# Each task gets a fresh Adam_Metaplastic optimizer with its own learning rate
# from `lrs`. After every epoch we log the current task's training metrics and
# the test metrics of every task (earlier, current and later ones).
for task_idx, task_train_loader in enumerate(train_loader_list):
    optimizer = Adam_Metaplastic(model.parameters(), lr=lrs[task_idx],
                                 meta=meta, weight_decay=decay)

    for epoch in range(1, epochs + 1):
        train(model, task_train_loader, task_idx, optimizer, device)

        # Row bookkeeping for this (task, epoch).
        data['task_order'].append(task_idx + 1)
        data['tsk'].append(task_names[task_idx])
        data['epoch'].append(epoch)
        data['lr'].append(optimizer.param_groups[0]['lr'])

        print(f"EPOCH: {epoch}")
        # Re-run the training loader through `test` to log train accuracy / loss.
        train_accuracy, train_loss = test(model, task_train_loader, device, task_idx,
                                          verbose=True)
        data['acc_tr'].append(train_accuracy)
        data['loss_tr'].append(train_loss)
        data['meta'].append(meta)

        for other_task_idx, eval_loader in enumerate(test_loader_list):
            is_current = (other_task_idx == task_idx)
            test_accuracy, test_loss = test(model, eval_loader, device, other_task_idx,
                                            verbose=is_current)
            data[f'acc_test_tsk_{other_task_idx + 1}'].append(test_accuracy)
            data[f'loss_test_tsk_{other_task_idx + 1}'].append(test_loss)


### Output to JSON File

In [None]:
# Save the collected metrics as JSON.
# Do NOT assign json.dump's result: it returns None, and rebinding `json` would
# break every later use of the json module (e.g. when plotting).
filename = 'pFMNIST_2'
if save_result:
    with open(filename, "w") as results_file:  # closed even if dumping fails
        json.dump(data, results_file)
    print(f"Saved results to {filename}")

## Plot Results

In [None]:
def plot_acc(filename, dataset_name='pFMNIST'):
    """Plot every task's test accuracy over training epochs from a results JSON.

    Args:
        filename: path to a JSON file holding `acc_test_tsk_<k>` lists (one value
            per epoch) and, optionally, a `meta` list.
        dataset_name: dataset name used in the legend and title.

    Every `acc_test_tsk_<k>` series found in the file is plotted, ordered by k.
    The title's m= value is taken from the first entry of the file's `meta` list,
    or shown as '?' when that list is missing or empty.
    """
    import json  # local import: do not depend on a possibly shadowed global `json`

    with open(filename) as f:
        plot_data = json.load(f)

    prefix = 'acc_test_tsk_'
    task_numbers = sorted(
        int(key[len(prefix):]) for key in plot_data
        if key.startswith(prefix) and key[len(prefix):].isdigit()
    )
    meta_values = plot_data.get('meta') or ['?']

    _, ax = plt.subplots(figsize=(15, 8))
    for k in task_numbers:
        ax.plot(plot_data[prefix + str(k)], label=f'Task {k}: {dataset_name}')
    ax.legend(loc='best')
    ax.set_ylabel('Test Accuracy', size=15)
    ax.set_xlabel('Epoch', size=15)
    ax.set_title(f'Test Accuracy of Sequentially Trained {dataset_name}, m={meta_values[0]}',
                 size=15)

# Function to Compute the Shannon Entropy of Each Dataset Class

Before calling the function, define the class-label mapping to pass in and load the datasets (e.g., from Keras or another external source).

In [None]:
# Load raw copies of MNIST, Fashion-MNIST and CIFAR-10 via Keras for the
# entropy analysis below.
# NOTE(review): `tf` is not imported anywhere in this notebook as shown;
# `import tensorflow as tf` must run first or this cell raises NameError.
(mnist_train_x, mnist_train_y), (mnist_test_x, mnist_test_y) = \
    tf.keras.datasets.mnist.load_data()
(fmnist_train_x, fmnist_train_y), (fmnist_test_x, fmnist_test_y) = \
    tf.keras.datasets.fashion_mnist.load_data()
(c10_train_x, c10_train_y), (c10_test_x, c10_test_y) = \
    tf.keras.datasets.cifar10.load_data()

# Human-readable class names for each dataset, keyed by integer label.
mnist_labels = dict(zip(range(10), map(str, range(10))))

fmnist_labels = dict(enumerate([
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]))

cifar10_labels = dict(enumerate([
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]))

In [None]:
from skimage.measure import shannon_entropy
def calculate_class_entropies(x, y, class_labels, entropy_fn=None):
  """Compute the mean per-image Shannon entropy of every class.

  Args:
    x: array of images; its first axis indexes the images.
    y: integer labels, one per image. Shape (N,) or (N, 1) both work.
    class_labels: dict mapping integer label -> class name (not modified).
    entropy_fn: per-image entropy function; defaults to skimage's
      `shannon_entropy`.

  Returns:
    DataFrame indexed by label with columns 'class' (name) and 'entropy'
    (mean entropy over that class's images, NaN if the class has no images).
  """
  if entropy_fn is None:
    entropy_fn = shannon_entropy

  class_entropies = pd.DataFrame.from_dict(class_labels, orient='index')
  class_entropies = class_entropies.rename(columns={0: "class"})

  # Flatten so that (N, 1) label arrays select rows the same way as (N,).
  labels = np.ravel(y)
  entropies = []
  for idx in class_labels:  # dict order matches the DataFrame row order
    images = x[labels == idx]
    if len(images) == 0:
      entropies.append(float('nan'))  # avoid np.mean([]) RuntimeWarning
    else:
      entropies.append(np.mean([entropy_fn(img) for img in images]))

  class_entropies['entropy'] = entropies
  return class_entropies