## Importing Libraries

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchsummary import summary
%matplotlib inline
import matplotlib.pyplot as plt
import random
import warnings
warnings.filterwarnings('ignore')
from torch.autograd import Variable
from copy import deepcopy
from tqdm import tqdm

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
device  # echo the selected device in the notebook

In [None]:
def mnist_imshow(img):
    """Show one MNIST image (784 values, any layout) as a 28x28 grayscale picture without axes."""
    picture = img.reshape([28, 28])
    plt.imshow(picture, cmap="gray")
    plt.axis('off')
    plt.show()

In [None]:
#@title Experiment Constants
# Lower `epochs` (e.g. to 2) for a quick smoke run.
epochs = 50        # training epochs per task
lr = 1e-3          # SGD learning rate
batch_size = 128   # mini-batch size for the train and test DataLoaders
sample_size = 200  # old-task images used to estimate the EWC Fisher information
hidden_size = 200  # width of each hidden layer of the MLP
num_task = 3       # number of tasks; the first task (index 0) keeps the original pixel order

## Dataset and Dataloader

In [None]:
class PermutedMNIST(datasets.MNIST):
    """MNIST whose images are flattened, scaled to [0, 1] and pixel-permuted.

    Each item is ``(image, label)`` where ``image`` is a float tensor holding
    the 28*28 pixels of one digit, divided by 255 and reordered by
    ``permute_idx``.

    Args:
        root: download/cache directory handed to torchvision's MNIST.
        train: load the training split if True, otherwise the test split.
        permute_idx: sequence of 28*28 pixel indices giving the new pixel
            order. ``None`` means the identity order (no permutation).

    Raises:
        ValueError: if ``permute_idx`` does not contain exactly 28*28 indices.
    """

    def __init__(self, root="~/.torch/data/mnist", train=True, permute_idx=None):
        super(PermutedMNIST, self).__init__(root, train=train, download=True)
        if permute_idx is None:
            permute_idx = list(range(28 * 28))
        if len(permute_idx) != 28 * 28:
            raise ValueError(
                f"permute_idx must contain {28 * 28} indices, got {len(permute_idx)}")

        # Flatten every image to one row and permute all rows with a single
        # index operation instead of looping over the images in Python.
        images = self.data.reshape(len(self.data), -1).float()
        permuted = images[:, permute_idx] / 255
        if self.train:
            self.training_data = permuted
            self.training_labels = self.targets
        else:
            self.testing_data = permuted
            self.testing_labels = self.targets

    def __getitem__(self, index):
        if self.train:
            img, target = self.training_data[index], self.training_labels[index]
        else:
            img, target = self.testing_data[index], self.testing_labels[index]

        return img, target

    def get_sample(self, sample_size):
        """Return ``sample_size`` distinct random images of this split as a list of tensors."""
        data = self.training_data if self.train else self.testing_data
        sample_idx = random.sample(range(len(self)), sample_size)
        return list(data[sample_idx])

In [None]:
def get_permute_mnist():
    """Create one train DataLoader and one test DataLoader per task.

    Task 0 uses the original pixel order; each following task uses a fresh
    random shuffle of the previous task's index list. The index list of every
    task is printed.

    Returns:
        ``(train_loader, test_loader)``: two dicts keyed by task index.
    """
    train_loaders, test_loaders = {}, {}
    pixel_order = list(range(28 * 28))  # identity order for the first task
    for task_id in range(num_task):
        train_loaders[task_id] = torch.utils.data.DataLoader(
            PermutedMNIST(train=True, permute_idx=pixel_order),
            batch_size=batch_size,
            num_workers=4,
        )
        test_loaders[task_id] = torch.utils.data.DataLoader(
            PermutedMNIST(train=False, permute_idx=pixel_order),
            batch_size=batch_size,
        )
        print(f'Index for  task {task_id} \n {pixel_order}')
        # Order for the next task (also computed after the last task; unused then).
        pixel_order = random.sample(pixel_order, len(pixel_order))
    return train_loaders, test_loaders


train_loader, test_loader = get_permute_mnist()

In [None]:
# Show the same training image under each task's pixel permutation.
sample_pos = None
for i in range(num_task):
    print(f' ------task {i} --------')
    images, labels = next(iter(train_loader[i]))
    if sample_pos is None:
        # Choose a position inside the image batch (not inside the
        # (images, labels) pair). get_permute_mnist builds the loaders without
        # shuffle=True, so the same position holds the same digit in every task.
        sample_pos = random.randrange(len(images))
    image, label = images[sample_pos], labels[sample_pos]
    print(image.shape)
    print("label : ", label.item())
    mnist_imshow(image)


## Network
A simple multilayer perceptron built only from fully connected (`nn.Linear`) layers with ReLU activations between them.

In [None]:
class MLP(nn.Module):
    """Fully connected classifier: 784 inputs -> 3 ReLU hidden layers -> 10 logits.

    Args:
        hidden_size: width of each of the three hidden layers.

    ``forward`` returns raw, unnormalised logits (last dimension 10); callers
    apply ``F.cross_entropy`` / ``F.log_softmax`` themselves.
    """

    def __init__(self, hidden_size=200):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, 10)

    def forward(self, input):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # No ReLU on the output layer: logits must be free to go negative.
        # Clamping them at 0 would give every negative logit zero gradient
        # through cross-entropy.
        return self.fc4(x)

model = MLP(hidden_size).to(device)
# torchsummary's summary() defaults to device="cuda"; pass the device that is
# actually in use so the summary also works on CPU-only machines.
summary(model, input_size=(1, 784), device=device.type)

## Normal Train Function

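`variable` helper: moves a tensor to the GPU when one is available and passes it through the legacy `torch.autograd.Variable` wrapper (optionally with `requires_grad`). [Reference](https://www.geeksforgeeks.org/variables-and-autograd-in-pytorch/)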

In [None]:
def variable(t: torch.Tensor, use_cuda=True, **kwargs):
    """Pass ``t`` through the legacy ``Variable`` constructor, on the GPU when allowed.

    The tensor is moved with ``.cuda()`` only when CUDA is available and
    ``use_cuda`` is truthy. Extra keyword arguments (for example
    ``requires_grad=True``) are forwarded unchanged to ``Variable``.
    """
    on_gpu = torch.cuda.is_available() and use_cuda
    return Variable(t.cuda() if on_gpu else t, **kwargs)

In [None]:
def normal_train(model: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 data_loader: torch.utils.data.DataLoader):
    """Train ``model`` for one epoch with plain cross-entropy.

    Args:
        model: network to train; it is switched to training mode.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        data_loader: yields ``(inputs, targets)`` batches.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    model.train()
    epoch_loss = 0
    for inputs, targets in data_loader:
        inputs, targets = variable(inputs), variable(targets)
        optimizer.zero_grad()
        output = model(inputs)
        loss = F.cross_entropy(output, targets)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

## Normal Test Function

In [None]:
def test(model: nn.Module,
         data_loader: torch.utils.data.DataLoader):
    """Return the classification accuracy of ``model`` on ``data_loader``.

    The model is switched to eval mode. The result is a plain Python float
    (correct predictions divided by ``len(data_loader.dataset)``), so it can be
    stored, compared and plotted without touching device tensors.
    """
    model.eval()
    correct = 0
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = variable(inputs), variable(targets)
            output = model(inputs)
            # The argmax of the logits is the argmax of their softmax, so no
            # softmax is needed to pick the predicted class.
            correct += (output.argmax(dim=1) == targets).sum().item()
    return correct / len(data_loader.dataset)

## Elastic Weight Consolidation

In [None]:
class EWC(object):
    """Elastic Weight Consolidation penalty anchored at the model's current parameters.

    On construction this estimates a diagonal Fisher information matrix (the
    "precision matrices") from ``dataset`` and snapshots the trainable
    parameters of ``model`` (the "means"). ``penalty`` then measures, per
    element and weighted by the Fisher estimate, how far a model's parameters
    have moved away from that snapshot.

    Args:
        model: network whose current trainable parameters become the anchor.
        dataset: list of single input tensors (in this notebook, images drawn
            from earlier tasks via ``get_sample``) used for the Fisher estimate.
    """

    def __init__(self,
                 model: nn.Module,
                 dataset: list):

        self.model = model
        self.dataset = dataset

        # Only trainable parameters are anchored and penalised.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._precision_matrices = self._diag_fisher()
        # Detached copies of the parameters as they are right now.
        self._means = {n: variable(p.detach().clone()) for n, p in self.params.items()}

    def _diag_fisher(self):
        """Estimate the diagonal Fisher as the mean squared gradient over ``self.dataset``.

        For each sample the model's own argmax prediction is used as the label,
        the NLL of that label is back-propagated, and the squared parameter
        gradients are averaged over the dataset. The model's train/eval mode
        is restored and its gradients are cleared afterwards.
        """
        precision_matrices = {
            n: variable(torch.zeros_like(p.detach())) for n, p in self.params.items()
        }
        n_samples = len(self.dataset)

        was_training = self.model.training
        # eval() only changes layer behaviour (e.g. dropout); backward() still
        # computes gradients in eval mode.
        self.model.eval()
        try:
            for sample in self.dataset:
                self.model.zero_grad()
                sample = variable(sample)
                output = self.model(sample).view(1, -1)
                label = output.max(1)[1].view(-1)
                loss = F.nll_loss(F.log_softmax(output, dim=1), label)
                loss.backward()

                for n, p in self.params.items():
                    if p.grad is not None:
                        precision_matrices[n] += p.grad.detach() ** 2 / n_samples
        finally:
            # Don't leave the last sample's gradients or eval mode behind.
            self.model.zero_grad()
            self.model.train(was_training)
        return precision_matrices

    def penalty(self, model: nn.Module):
        """Return sum over anchored params of F * (param - anchored_param) ** 2.

        Parameters of ``model`` that were not trainable when this object was
        built have no anchor and are skipped.
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self._means:
                continue
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return loss

### EWC Train Function
Same loop as the normal training function, except that the loss also includes the EWC penalty scaled by `importance`.

In [None]:
def ewc_train(model: nn.Module,
              optimizer: torch.optim.Optimizer,
              data_loader: torch.utils.data.DataLoader,
              ewc: EWC,
              importance: float):
    """Train ``model`` for one epoch on cross-entropy plus the EWC penalty.

    Args:
        model: network to train; it is switched to training mode.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        data_loader: yields ``(inputs, targets)`` batches of the current task.
        ewc: penalty anchored at the parameters learned on earlier tasks.
        importance: hyperparameter scaling the EWC penalty against the task loss.

    Returns:
        The mean of the per-batch total losses (task loss + scaled penalty).
    """
    model.train()
    epoch_loss = 0
    for inputs, targets in data_loader:
        inputs, targets = variable(inputs), variable(targets)
        optimizer.zero_grad()
        output = model(inputs)
        # Task loss plus the EWC term that pulls parameters back towards
        # their values on earlier tasks.
        loss = F.cross_entropy(output, targets) + importance * ewc.penalty(model)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)


## Plot Functions

In [None]:
def loss_plot(x):
    """Plot each task's per-epoch training loss on one shared epoch axis.

    ``x`` maps task index -> list of per-epoch losses. The curve for task ``t``
    starts at epoch ``t * epochs`` (global ``epochs``) and is as long as its
    data, so a task with no recorded losses (e.g. task 0 when ``ewc_process``
    loads pretrained weights) simply draws nothing instead of raising.
    """
    for t, v in x.items():
        start = t * epochs
        plt.plot(list(range(start, start + len(v))), v, label="Task " + str(t))
    plt.legend(loc='upper right')
    plt.title("Loss Plot")
    plt.show()

def accuracy_plot(x):
    """Plot each task's test accuracy over the whole run.

    ``x`` maps task index -> list of accuracies recorded once per epoch. Every
    curve ends at epoch ``num_task * epochs`` (globals) and is right-aligned to
    its own length, so tasks whose early epochs were skipped still plot.
    Values are converted with ``float`` so 0-dim tensors (including CUDA
    tensors) are accepted as well as plain numbers.
    """
    end = num_task * epochs
    for t, v in x.items():
        plt.plot(list(range(end - len(v), end)), [float(a) for a in v], label="Task " + str(t))
    plt.ylim(0, 1)
    plt.legend(loc='upper right')
    plt.title("Accuracy Plot")
    plt.show()

## Standard Process Execution

In [None]:
def standard_process(epochs, use_cuda=True, weight=True):
    """Train one MLP sequentially on every task with plain SGD (no EWC).

    Args:
        epochs: training epochs per task.
        use_cuda: move the model to the GPU when one is available.
        weight: if truthy, snapshot the model weights after task 0.

    Returns:
        ``(loss, acc, weight)``:
        - ``loss[task]`` is the per-epoch training loss on that task.
        - ``acc[task]`` is that task's test accuracy, recorded every epoch
          from the moment it is introduced.
        - ``weight`` is an independent copy of the state_dict after task 0
          (or the falsy ``weight`` argument unchanged).
    """
    model = MLP(hidden_size)
    if torch.cuda.is_available() and use_cuda:
        model.cuda()
    optimizer = optim.SGD(params=model.parameters(), lr=lr)

    loss, acc = {}, {}  # keyed by task index
    for task in range(num_task):
        loss[task] = []
        acc[task] = []
        for _ in tqdm(range(epochs)):
            loss[task].append(normal_train(model, optimizer, train_loader[task]))
            for sub_task in range(task + 1):
                acc[sub_task].append(test(model, test_loader[sub_task]))
        if task == 0 and weight:
            # state_dict() tensors share storage with the live parameters, and
            # later SGD steps update those in place. Deep-copy so the snapshot
            # really holds the weights as they are after task 0.
            weight = deepcopy(model.state_dict())
    return loss, acc, weight

In [None]:
#@title Training and Output of Standard Process
# `weight` is returned for optional reuse via ewc_process(..., weight=weight).
loss, acc, weight = standard_process(epochs=epochs)

In [None]:
# Two statements rather than a tuple expression, which would also echo a
# stray `(None, None)` in the notebook.
loss_plot(loss)
accuracy_plot(acc)

In [None]:
weight  # display the state_dict returned by standard_process

## EWC Process Execution

In [None]:
def ewc_process(epochs, importance, use_cuda=True, weight=None):
    """Train one MLP sequentially on every task, adding the EWC penalty from task 1 on.

    Args:
        epochs: training epochs per task.
        importance: scale of the EWC penalty.
        use_cuda: move the model to the GPU when one is available.
        weight: optional state_dict to load instead of training task 0.

    Returns:
        ``(loss, acc)``:
        - ``loss[task]`` is the per-epoch training loss on that task. It is
          empty for task 0 when ``weight`` is given.
        - ``acc[task]`` is that task's test accuracy, recorded every epoch
          from the moment it is trained or introduced.
    """
    model = MLP(hidden_size)
    if torch.cuda.is_available() and use_cuda:
        model.cuda()
    optimizer = optim.SGD(params=model.parameters(), lr=lr)

    loss, acc = {}, {}
    for task in range(num_task):
        loss[task] = []
        acc[task] = []

        if task == 0:
            if weight:
                model.load_state_dict(weight)
            else:
                for _ in tqdm(range(epochs)):
                    loss[task].append(normal_train(model, optimizer, train_loader[task]))
                    acc[task].append(test(model, test_loader[task]))
        else:
            # Fisher-estimation samples drawn from all earlier tasks.
            old_tasks = []
            for sub_task in range(task):
                old_tasks.extend(train_loader[sub_task].dataset.get_sample(sample_size))
            old_tasks = random.sample(old_tasks, k=sample_size)
            # Build the penalty once, before training on the new task, so it
            # stays anchored at the parameters learned on the previous tasks.
            # Rebuilding it every epoch would re-anchor to the drifting
            # current weights (and recompute the Fisher each time).
            ewc = EWC(model, old_tasks)
            for _ in tqdm(range(epochs)):
                loss[task].append(ewc_train(model, optimizer, train_loader[task], ewc, importance))
                for sub_task in range(task + 1):
                    acc[sub_task].append(test(model, test_loader[sub_task]))

    return loss, acc

In [None]:
#@title Training and Output of EWC Process
# No `weight` is passed, so task 0 is trained from scratch inside ewc_process.
loss_ewc, acc_ewc = ewc_process(epochs=epochs, importance=1000)

In [None]:
# Two statements rather than a tuple expression, which would also echo a
# stray `(None, None)` in the notebook.
loss_plot(loss_ewc)
accuracy_plot(acc_ewc)

In [None]:
# Compare how test accuracy on the first task evolves under plain SGD vs. EWC.
# float() accepts plain numbers as well as 0-dim (possibly CUDA) tensors.
plt.plot([float(a) for a in acc[0]], label="sgd")
plt.plot([float(a) for a in acc_ewc[0]], label="ewc")
plt.xlabel("epoch")
plt.ylabel("task 0 test accuracy")
plt.title("Task 0 accuracy: SGD vs EWC")
plt.legend()
plt.show()