In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): `google.colab` is presumably only importable inside a Colab runtime;
# nothing else visible in this notebook reads from the mounted path - confirm it is needed.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
import os, sys
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms
from sklearn.utils import shuffle
from copy import deepcopy
import time

In [None]:
def get_data(datapath : str):
  """
  Placeholder from the initial outline; a working `get_data(datapath, tasknum, seed)`
  is defined in the "Data" section of this notebook and replaces this stub once run.

  datapath : directory where the dataset is (or will be) stored

  Raises NotImplementedError: the original body returned three undefined names
  and therefore failed with an opaque NameError when called.
  """
  raise NotImplementedError(
      "get_data stub: run the 'Data' cell, which defines the real get_data()")

def train_epoch():
  """Outline placeholder: does nothing and returns None (see EWC.train_epoch for the real loop)."""
  return None

def train_model():
  """Outline placeholder: does nothing and returns None (see EWC.train for the real routine)."""
  return None

def eval_model():
  """Outline placeholder: does nothing and returns None (see EWC.eval for the real routine)."""
  return None

def cal_FIM(old_models, cur_model)->float:
  """
  Outline placeholder: ignores both arguments and returns None, despite the
  `float` annotation (see EWC.cal_Fisher for the real Fisher computation).
  """
  return None


# Network

In [None]:
class Network(nn.Module):
    """
    Fully connected classifier: two hidden ReLU layers with dropout (p=0.5),
    followed by either one shared output layer or one output layer per task.
    """

    def __init__(self, inputsize, taskcla, hiddensize=400, split=False):
        """
        inputsize : [n_channels, height, width]
        taskcla : [(taskid, n_cla),...]
        split : True if using split experiments (one output head per task)
        hiddensize : the number of unit in hidden layers
        """
        super().__init__()
        ncha, height, width = inputsize
        self.taskcla = taskcla
        self.split = split
        # Use height*width (not height**2) so non-square inputs are sized correctly.
        self.fc1 = nn.Linear(ncha * height * width, hiddensize)
        self.fc2 = nn.Linear(hiddensize, hiddensize)
        self.drop = nn.Dropout(0.5)
        if split:
            # One head per task, appended in the order the tasks are listed.
            self.last = nn.ModuleList(
                [nn.Linear(hiddensize, n_cla) for _, n_cla in taskcla]
            )
        else:   # not split experiments: single head sized by the first task
            self.last = nn.Linear(hiddensize, taskcla[0][1])

    def forward(self, x):
        """
        x : batch of inputs, flattened to (batch, n_features) before the first layer
        return : logits tensor, or a list of logits tensors (one per task) if split
        """
        # reshape (rather than view) also accepts non-contiguous inputs
        h = x.reshape(x.size(0), -1)
        h = self.drop(F.relu(self.fc1(h)))
        h = self.drop(F.relu(self.fc2(h)))
        if self.split:
            # Heads are looked up by task id, so ids must be valid positions in self.last.
            return [self.last[t](h) for t, _ in self.taskcla]
        return self.last(h)


# Method: Elastic Weight Consolidation (EWC)

In [None]:
class EWC():
    """
    Elastic Weight Consolidation trainer for a sequence of tasks.

    After each task, a snapshot of the parameters and a diagonal Fisher
    estimate are stored. While training task k > 0 the loss becomes
        CE + lamb * sum_{t<k} sum_i F_t[i] * (theta[i] - theta_t[i])^2 / 2
    """

    def __init__(self, model, nepochs=100, sbatch=256, lr=0.001, lamb=100, split=False):
        """
        model : nn.Module; forward returns logits, or a list of per-task logits if split
        nepochs : number of training epochs per task
        sbatch : mini-batch size
        lr : learning rate of the SGD optimizer
        lamb : weight of the EWC penalty
        split : True if the model has one output head per task
        """
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        # Move the model before building the optimizer so both refer to the
        # parameters on the device the batches are indexed on.
        self.model = model.to(self.device)
        # Kept for backward compatibility: an alias of `model` (not a copy); unused here.
        self.model_old = model
        self.fisher = {}        # taskid -> {param_name: diagonal Fisher estimate}
        self.params_opt = {}    # taskid -> {param_name: parameters after that task}
        self.nepochs = nepochs
        self.sbatch = sbatch
        self.lr = lr
        self.split = split
        self.lamb = lamb

        self.ce = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

    def _batches(self, nsamples, shuffle=False):
        """Yield index tensors covering range(nsamples) in chunks of at most sbatch."""
        order = np.arange(nsamples)
        if shuffle:
            np.random.shuffle(order)
        order = torch.LongTensor(order).to(self.device)
        for start in range(0, nsamples, self.sbatch):
            # slicing past the end simply yields the shorter final batch
            yield order[start:start + self.sbatch]

    def _logits(self, taskid, images):
        """Run the model and select the head of `taskid` when in split mode."""
        outputs = self.model(images)
        return outputs[taskid] if self.split else outputs

    def train(self, taskid, xtrain, ytrain, data):
        """
        Train on one task for nepochs, then store the parameter snapshot and
        Fisher estimate used to regularise later tasks.

        taskid : index of the task (0 means no penalty is applied)
        xtrain, ytrain : training inputs and integer labels
        data : unused; kept so existing callers keep working
        """
        for e in range(self.nepochs):
            clock0 = time.time()
            self.train_epoch(taskid, xtrain, ytrain)
            clock1 = time.time()
            train_loss, train_acc = self.eval(taskid, xtrain, ytrain)
            clock2 = time.time()
            if (e % 5 == 0):
                print('| Epoch {:3d}, train_time={:5.1f}ms, eval_time={:5.1f}ms | '
                      'Train: loss={:.3f}, acc={:5.1f}% |'.format(
                          e+1, 1000*(clock1-clock0), 1000*(clock2-clock1),
                          train_loss, 100*train_acc))

        self.params_opt[taskid] = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }
        self.fisher[taskid] = self.cal_Fisher(taskid, xtrain, ytrain)

        return

    def train_epoch(self, taskid, xtrain, ytrain):
        """One pass of mini-batch SGD over a shuffled copy of the training set."""
        self.model.train()
        for b in self._batches(xtrain.size(0), shuffle=True):
            outputs = self._logits(taskid, xtrain[b])
            loss = self.criterion(taskid, outputs, ytrain[b])
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return

    def eval(self, taskid, x, y):
        """
        return : (mean loss, accuracy) over all samples of x, y.
        The loss is the full criterion, i.e. it includes the EWC penalty for taskid > 0.
        """
        total_loss = 0.0
        total_hits = 0.0
        nsamples = x.size(0)
        self.model.eval()

        # No gradients are needed for evaluation; skip building the graph.
        with torch.no_grad():
            for b in self._batches(nsamples):
                targets = y[b]
                outputs = self._logits(taskid, x[b])
                _, pred = outputs.max(1)    # pred is index of max value
                loss = self.criterion(taskid, outputs, targets)

                total_loss += loss.item() * len(b)
                total_hits += (pred == targets).float().sum().item()

        return total_loss / nsamples, total_hits / nsamples

    def criterion(self, taskid, outputs, targets):
        """
        Cross-entropy plus, for taskid > 0, the EWC penalty over all tasks < taskid.
        """
        loss_ce = self.ce(outputs, targets)
        if taskid == 0:
            return loss_ce
        loss_reg = 0
        for t in range(taskid):
            for name, param in self.model.named_parameters():
                fisher = self.fisher[t][name]
                opt_param = self.params_opt[t][name]
                # Use `param` itself (not param.data) so the penalty is part of
                # the autograd graph and actually pulls the weights back.
                loss_reg = loss_reg + torch.sum(fisher * (param - opt_param).pow(2)) / 2
        return loss_ce + self.lamb * loss_reg

    def cal_Fisher(self, taskid, xtrain, ytrain):
        """
        return fisher[name_param]

        Diagonal Fisher approximation: squared gradient of the mean batch
        cross-entropy, weighted by the batch size and averaged over all samples.
        """
        print("Calculating Fisher Information Matrix ...")
        fisher = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
        }
        # Train mode is kept from the original (any dropout stays active);
        # it is not required for computing gradients.
        self.model.train()

        nsamples = xtrain.size(0)
        for b in self._batches(nsamples, shuffle=True):
            self.model.zero_grad()
            outputs = self._logits(taskid, xtrain[b])
            # Plain cross-entropy only: the EWC penalty must not leak into the Fisher estimate.
            loss = self.ce(outputs, ytrain[b])
            loss.backward()

            # Get fisher through gradient of param
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    # Instead of Calculating Hessian, FIM approximates it
                    # by computing first order derivations.
                    # Weight by the real batch length (the last batch may be shorter).
                    fisher[name] += param.grad.detach().pow(2) * len(b)

        # average fisher over all data in this task
        for name in fisher:
            fisher[name] /= nsamples
        print("Calculated Fisher Information Matrix.")
        return fisher


            



# Data: permuted MNIST

In [None]:
def get_data(datapath, tasknum=10, seed=0):
    """
    Build a sequence of permuted-MNIST tasks.

    Every task uses the same normalised MNIST images with its own fixed random
    permutation of the 28*28 pixel positions.

    datapath : directory where torchvision stores / downloads MNIST
    tasknum : number of permuted tasks to generate
    seed : seed for numpy's global RNG (set here, so it also affects later numpy draws)

    return : (data, taskcla, size)
        data[i] = {'name', 'ncla', 'train': {'x', 'y'}, 'test': {'x', 'y'}} for each task i,
                  plus data['ncla'] = total number of classes over all tasks
        taskcla = [(taskid, n_cla), ...]
        size = [n_channels, height, width]
    """
    np.random.seed(seed)
    size = [1, 28, 28]
    npixels = size[0] * size[1] * size[2]
    mean, std = 0.1307, 0.3081     # normalisation constants applied to the [0, 1] pixels

    # Flatten and normalise each split once; the result does not depend on the task.
    splits = {}
    for s, is_train in (('train', True), ('test', False)):
        dset = datasets.MNIST(datapath, train=is_train, download=True)
        # `.data` / `.targets` replace the deprecated train_data/test_data/train_labels/test_labels
        images = dset.data.view(dset.data.shape[0], -1).float()
        images = (images / 255 - mean) / std
        labels = torch.as_tensor(dset.targets, dtype=torch.long)
        splits[s] = (images, labels)

    data = {}
    for i in range(tasknum):
        # one permutation per task, shared by its train and test split
        permutation = np.random.permutation(npixels)
        data[i] = {'name': 'pmnist-{:d}'.format(i), 'ncla': 10}
        for s in ['train', 'test']:
            images, labels = splits[s]
            data[i][s] = {
                'x': images[:, permutation].view(-1, size[0], size[1], size[2]),
                'y': labels,    # the same label tensor is shared by all tasks
            }

    taskcla = [(t, data[t]['ncla']) for t in range(tasknum)]
    data['ncla'] = sum(n_cla for _, n_cla in taskcla)

    return data, taskcla, size


# Main

In [None]:
def main(datapath='../../dat/', tasknum=10, nepochs=100, sbatch=256, lr=0.001, lamb=400, seed=0):
    """
    Train one network with EWC on a sequence of permuted-MNIST tasks and report
    the test accuracy on every task seen so far after each training stage.

    datapath : directory passed to get_data for the dataset
    tasknum : number of tasks
    nepochs, sbatch, lr, lamb : forwarded to EWC
    seed : seed for get_data (numpy) and for torch

    return : (acc, loss), both float32 arrays of shape (tasknum, tasknum);
             entry [i, j] is measured on the test set of task j after training task i
             (entries with j > i stay 0)
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(seed)

    print("Loading data...")
    data, taskcla, inputsize = get_data(datapath, tasknum=tasknum, seed=seed)
    print('Input size =', inputsize, '\nTask info =', taskcla)
    # Inits
    print('Inits...')
    # Place the model explicitly instead of switching the global default tensor
    # type to CUDA, which fails on CPU-only machines.
    net = Network(inputsize, taskcla, hiddensize=400, split=False).to(device)
    appr = EWC(net, nepochs=nepochs, sbatch=sbatch, lr=lr, lamb=lamb, split=False)

    ntasks = len(taskcla)
    acc = np.zeros((ntasks, ntasks), dtype=np.float32)
    loss = np.zeros((ntasks, ntasks), dtype=np.float32)
    for i in range(ntasks):
        print('*' * 100)
        print('Task {:2d} ({:s})'.format(i, data[i]['name']))
        print('*' * 100)

        xtrain = data[i]['train']['x'].to(device)
        ytrain = data[i]['train']['y'].to(device)

        appr.train(taskid=i, xtrain=xtrain, ytrain=ytrain, data=data)
        print('-' * 100)

        # Test on every task learned so far
        for j in range(i+1):
            xtest = data[j]['test']['x'].to(device)
            ytest = data[j]['test']['y'].to(device)
            loss[i,j], acc[i,j] = appr.eval(taskid=j, x=xtest, y=ytest)
            print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}% <<<'.format(
                j, data[j]['name'], loss[i,j], 100 * acc[i,j]))

    # Done
    print('*' * 100)
    print('Accuracies =')
    for row in acc:
        print('\t', end='')
        for value in row:
            print('{:5.1f}% '.format(100 * value), end='')
        print()
    print('*' * 100)

    print()
    # acc holds fractions in [0, 1]; scale by 100 because the value is printed with '%'
    avg_acc = np.mean(acc[-1, :])
    print ('ACC: {:5.4f}%'.format(100 * avg_acc))
    print()

    # Final accuracy minus accuracy right after learning each task, averaged over
    # all tasks (the last task contributes 0 by construction).
    ucb_bwt = (acc[-1] - np.diag(acc)).mean()
    print ('BWT : {:5.2f}%'.format(100 * ucb_bwt))

    print('Done!')

    return acc, loss




# Run

In [None]:
# Run the full experiment; main() returns the accuracy and loss matrices
# (row = task just trained, column = task evaluated).
acc, loss = main()

Loading data...




Input size = [1, 28, 28] 
Task info = [(0, 10), (1, 10), (2, 10), (3, 10), (4, 10), (5, 10), (6, 10), (7, 10), (8, 10), (9, 10)]
Inits...
****************************************************************************************************
Task  0 (pmnist-0)
****************************************************************************************************
| Epoch   1, train_time=421.5ms, eval_time=265.0ms |                         Train: loss=2.234, acc= 38.0% |
| Epoch   6, train_time=401.4ms, eval_time=222.3ms |                         Train: loss=1.645, acc= 68.7% |
| Epoch  11, train_time=399.6ms, eval_time=188.9ms |                         Train: loss=0.896, acc= 80.6% |
| Epoch  16, train_time=359.6ms, eval_time=210.8ms |                         Train: loss=0.616, acc= 84.5% |
| Epoch  21, train_time=369.5ms, eval_time=173.9ms |                         Train: loss=0.505, acc= 86.3% |
| Epoch  26, train_time=367.7ms, eval_time=177.6ms |                         Train: loss=0.446, 

NameError: ignored