In [None]:
import sys
import time
import numpy as np
import logging
import torch
import torchvision
import os
import torchvision.transforms as transforms

# from dataload import create_dataloader, data_create  # , get_split_cifar100
# sys.path entries must be directories (or zip archives); the path of the notebook
# file "dataload.ipynb" itself is ignored by the import system, so add the directory
# that contains it instead. Guarded so re-running this cell does not pile up duplicates.
# NOTE(review): a .ipynb file is not importable by the standard import machinery;
# `import dataload` still needs a dataload.py somewhere on sys.path.
_dataload_dir = os.path.dirname(os.path.abspath("dataload.ipynb"))
if _dataload_dir not in sys.path:
    sys.path.append(_dataload_dir)

In [None]:
'''Train the permuted-MNIST (PMNIST) continual-learning tasks.

PyTorch gradient hooks proved unreliable for freezing earlier tasks' weights, so
after each optimiser step the weights are compared against the previous task's
checkpoint manually. That extra work makes this code unsuitable for measuring
training time.'''
import sys
import time
import numpy as np
import logging

# Root-logger configuration: install the default handler via basicConfig() (per the
# logging docs it does nothing if the root logger already has handlers), then raise
# the threshold so only ERROR and above get through.
root_logger = logging.getLogger()
logging.basicConfig()
root_logger.setLevel(logging.ERROR)
# Suppressed by the ERROR threshold set just above.
logging.info("Caution: This is the root logger!")

import torch
import torchvision
import torchvision.transforms as transforms

# Add the parent directory so the project modules imported below (model, dataload,
# res18, util_res) can be found -- presumably they live one level up. Guarded so
# re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

from model import (ClassifierMLP, _prune)

from dataload import create_dataloader, data_create  # , get_split_cifar100
from res18 import AResNet18, CNet, ResNet18
from util_res import (all_task_eval, change_weights, check_model_version, dataset_eval,
                      learnable_weights, load_and_check, turn_off_adjx)

# True when a CUDA device is usable; gates the .cuda() moves of the model and batches.
gpu_boole = bool(torch.cuda.is_available())

class Myclass:
    """Hyper-parameters for the PMNIST MLP continual-learning run.

    Plays the role of an argparse namespace inside the notebook: every setting is
    a class attribute. The training loop in this cell reads these values and, in
    places, overwrites them on the instance (e.g. restart_tsk, model_pth, lr,
    epochs, prune_epoch).
    """
    # --- experiment / data ---
    name = 'mlp_exp1'
    batch_size = 128
    dataroot = '../data/'
    dataset = 'pmnist'
    # --- model ---
    model = 'mlp'
    n_class = 10
    tasks = 4
    im_size = 28
    cn = 1
    hidden_size = 100
    # --- optimisation ---
    epochs = 30
    lr = 1e-3
    lr_adj = 1e-2
    decay = 0
    loss = 'ce'
    optim = 'adam'
    # --- accuracy threshold / pruning ---
    train_p = 95.0
    wt_para = 0.22
    prune_para = 0.999999
    prune_epoch = 30
    # --- checkpointing / run control ---
    load_model = True
    restart_tsk = 0
    model_pth = ''
    turn_off = True
    savename = 'pmnist_task_'
    # --- reporting / data loading ---
    print_ev = 10
    workers = 4


args = Myclass()


for tsktodo in range(args.tasks - 1):
  # Each pass trains task `tsktodo + 1`; args.model_pth names the checkpoint saved
  # for task `tsktodo` (presumably what load_and_check loads when load_model is set).
  args.restart_tsk = tsktodo + 1
  # Same '{name}{savename}{task}.pt' pattern used when checkpoints are saved and
  # checked, instead of a hard-coded "mlp_exp1pmnist_task_" prefix.
  args.model_pth = "{}{}{}.pt".format(args.name, args.savename, tsktodo)

  # NOTE(review): the `tasks` argument differs per architecture (+1, none, +2) -- confirm intended.
  if args.model == 'ares18':
      net = AResNet18(num_classes=args.n_class, tasks=args.tasks+1, channels=args.cn)
  elif args.model == 'res18':
      net = ResNet18(num_classes=args.n_class, channels=args.cn)
  elif args.model == 'cnet':
      net = CNet(num_classes=args.n_class, channels=args.cn, tasks=args.tasks)
  elif args.model == 'mlp':
      net = ClassifierMLP(image_size=args.im_size, output_shape=args.n_class, tasks=args.tasks+2,
                          layer_size=args.hidden_size, bn_boole=True)
  else:
      raise Exception("model architecture not found")
  if args.loss == 'ce':
      loss_metric = torch.nn.CrossEntropyLoss()
  else:
      # Fail fast: otherwise `loss_metric` is unbound (or stale kernel state is reused)
      # and the failure only surfaces at its first use below.
      raise Exception("Unknown loss error.")

  if gpu_boole:
      net = net.cuda()

  if args.load_model:
      # Load the saved checkpoint into `net` and run the consistency checks on it.
      net = load_and_check(args, net, loss_metric)

  # lr_adj == -1 means "use the main learning rate for the adjx parameters too".
  if args.lr_adj == -1:
      args.lr_adj = args.lr
  pruned = False

  # Two parameter groups: parameters whose name lacks 'adjx' (rate args.lr) and the
  # 'adjx' parameters (rate args.lr_adj, weight decay args.decay). Group 0 is the
  # non-adjx group; the LR-decay logic later relies on that order.
  base_params = [p for pname, p in net.named_parameters() if 'adjx' not in pname]
  adjx_params = [p for pname, p in net.named_parameters() if 'adjx' in pname]
  if args.optim == 'sgd':
      optimizer = torch.optim.SGD([
          {'params': base_params, 'lr': args.lr, 'momentum': 0.9, 'weight_decay': 5e-4},
          {'params': adjx_params, 'lr': args.lr_adj, 'momentum': 0, 'weight_decay': args.decay},
      ])
  elif args.optim == 'adam':
      # Adam has no momentum hyper-parameter; the 'momentum' entries simply ride
      # along in the group dicts.
      optimizer = torch.optim.Adam([
          {'params': base_params, 'lr': args.lr, 'momentum': 0},
          {'params': adjx_params, 'lr': args.lr_adj, 'momentum': 0, 'weight_decay': args.decay},
      ])
  else:
      raise Exception("Unknown optimizer error.")

  # Index of the task trained in this pass.
  j = args.restart_tsk
  print(j)
  if j > 0:
      # Gradient hooks are unstable, so freezing is verified the slow way: report how
      # much of the network is still trainable, then check that every earlier task's
      # adjx and fixed weights still match the previous task's checkpoint.
      free_wts = learnable_weights(net, args, verbose=True)
      logging.info("\n Parameters that are still trainable = {} %".format(free_wts*100))

      prev_task_ckpt = './{}{}{}.pt'.format(args.name, args.savename, j-1)
      for ix in range(j):
          error_signal = check_model_version(net, old_path=prev_task_ckpt, task=ix)
          # Explicit `== False` kept on purpose: unlike `not error_signal`, it does not
          # accept falsy non-False values such as None.
          assert error_signal == False

  if args.dataset == 'pmnist':
      perm_path = './permutations.pt'
      if args.restart_tsk == 0:
          # Fresh run: draw args.tasks + 1 random pixel permutations as int64 index
          # tensors (one row per permutation) and save them so later runs reuse them.
          n_pixels = args.im_size * args.im_size
          permutations = torch.stack([
              torch.as_tensor(np.random.permutation(n_pixels), dtype=torch.long)
              for _ in range(args.tasks + 1)
          ])
          torch.save(permutations, perm_path)
      else:
          # Resumed run: reuse the permutations saved by the restart_tsk == 0 run.
          permutations = torch.load(perm_path)
      logging.info("Loading {} dataset".format(args.dataset))
      training, testing = data_create(args)
      train_loader, test_loader = create_dataloader(training, testing, args)
  elif args.dataset in ('s-cifar100', 'cifar100'):  # args.dataroot changes
      logging.info("Loading {} dataset".format(args.dataset))
      # NOTE(review): the get_split_cifar100 import is commented out in this notebook;
      # unless it is defined elsewhere, this branch raises NameError.
      train_loader, test_loader = get_split_cifar100(args, j+1, class_size=args.n_class)
  else:
      raise Exception("Dataset {} Error - not available".format(args.dataset))


  EPOCHS = list(range(args.epochs))
  # Iterate over the list object itself: the LR-decay logic later in this loop
  # extends EPOCHS in place, and list iteration picks up the appended epochs.
  for epoch in EPOCHS:
      t1 = time.time()  # epoch start, used for the time-left estimate
      print("Task:", j, "- Epoch:", epoch)
      if j > 0:
          if logging.getLogger().getEffectiveLevel() <= logging.INFO:
              print("Test acc for all previous tasks:")
              test_acc, test_loss = all_task_eval(net, task=j-1, loss_metric=loss_metric, args=args)
          # Re-freeze the adjx/BN parameters of every earlier task before training.
          for ix in range(j):
              net = turn_off_adjx(net, ix, bn_off=True)

      # Loop-invariant values: the previous task's checkpoint (used to reset old-task
      # weights after each step) and the flattened image size.
      prev_ckpt_path = './{}{}{}.pt'.format(args.name, args.savename, j-1)
      n_pixels = args.im_size * args.im_size
      for x, y in train_loader:
          optimizer.zero_grad()
          if gpu_boole:
              x, y = x.cuda(), y.cuda()
          if args.dataset == 'pmnist':
              # Flatten each image and apply this task's fixed pixel permutation.
              x = x.view(-1, n_pixels)[:, permutations[j]]
              y = y.view(-1)

          outputs = net(x, task=j)
          loss = loss_metric(outputs, y)
          loss.backward()
          optimizer.step()
          if j > 0:
              # Hooks are unreliable here, so after every step the old tasks' adjx
              # weights are manually kept equal to the previous checkpoint. With
              # path=True, change_weights presumably reloads that file each call --
              # the slow part noted in the cell docstring.
              net = change_weights(net, old_model=prev_ckpt_path, present_task=j, path=True)

          # Release batch tensors before the next iteration.
          del loss, x, y, outputs

      ###### Epoch Stats Printing ###########
      # Evaluate on the print schedule and on each of the last 10 epochs.
      if epoch % args.print_ev == 0 or (args.epochs - epoch) < 10:
          train_acc, train_loss = dataset_eval(net, train_loader, loss_=loss_metric, args=args,
                                               print_txt='Training', verbose=1, task=j)
          if logging.getLogger().getEffectiveLevel() <= logging.INFO:
              test_acc, test_loss = dataset_eval(net, test_loader, loss_=loss_metric, args=args,
                                                 print_txt='Testing', verbose=1, task=j)
          # round_=True: presumably evaluates with the binarised mapping (per the label).
          test_acc_true, test_loss_true = dataset_eval(net, test_loader, loss_=loss_metric, args=args,
                                                       print_txt='Testing(Binary Mapping)', verbose=1,
                                                       task=j, round_=True)
          t2 = time.time()
          # Extrapolate this epoch's wall time over the epochs still to run (the current
          # one is finished, hence the -1). logging formats lazily with %, so values
          # must fill %-placeholders; bare extra arguments make the handler hit a
          # "not all arguments converted" formatting error when the record is emitted.
          remaining_epochs = args.epochs - epoch - 1
          logging.debug('Time left for task: %.2f minutes', (t2 - t1) / 60 * remaining_epochs)
          print()
      ######################################

      ############## Pruning ##############
      # Pruning runs only during the last `prune_epoch` epochs, and only when
      # epochs > prune_epoch. NOTE(review): with the Myclass defaults
      # (epochs == prune_epoch == 30) this branch -- including the LR-decay
      # extension nested inside it -- never executes; the extension below also
      # bumps both values by the same amount, so they stay equal.
      if (epoch >= (args.epochs - args.prune_epoch) and args.epochs>args.prune_epoch):
          logging.info("Relevance Mapping Pruning...")
          # Earlier variant without the weight-sparsity options:
          # _prune(net,task=j,prune_para=args.prune_para)
          # _prune comes from model.py; presumably prunes task j's mapping (prune_para)
          # and, with wt_sp=True, the weights (wt_para) -- confirm in model.py.
          _prune(net,task=j,prune_para=args.prune_para, wt_sp=True, wt_para=args.wt_para)
          pruned=True  # not read anywhere else in this cell
          # Same schedule as the epoch-stats printing above.
          if (epoch)%args.print_ev==0 or (args.epochs-epoch)<10:
              if logging.getLogger().getEffectiveLevel()<=20:  # 20 == logging.INFO
                  test_acc, test_loss= dataset_eval(net, test_loader, loss_=loss_metric, args=args, \
                      print_txt='Testing after mapping pruning', verbose = 1, task = j)
                  test_acc_true, test_loss_true= dataset_eval(net, test_loader, loss_=loss_metric, args=args, \
                      print_txt='Testing(Binary) post-prune', verbose = 1, task = j, round_=True)
          # net = turn_off_adjx(net, j, bn_off=True)
          # change this accuracy for other tasks ##################
          # LR-decay extension: if training accuracy at epoch (epochs - 11) is still
          # below train_p (and lr > 1e-5), divide the base LR by 10 and append 29 epochs.
          # NOTE(review): train_acc is only refreshed on stats epochs
          # (epoch % print_ev == 0 or the last 10 epochs); epoch == epochs - 11 is
          # neither unless (epochs - 11) % print_ev == 0, so this usually compares a
          # stale value from an earlier epoch.
          if args.train_p != -1:
              if epoch==(args.epochs-11) and epoch<210: # TODO: expose the 210-epoch cap as a config option
                  if train_acc < args.train_p and args.lr>1e-5: # this value changes for different experiments/datasets/models
                      logging.info("changing LR and adding a few epochs")
                      # `+=` extends EPOCHS in place, so the running `for epoch in EPOCHS`
                      # loop picks up the new epochs; epochs/prune_epoch shift by the same 29.
                      # NOTE(review): `args` is created once for the whole task loop, so
                      # these changes (and the lr below) carry into later tasks.
                      EPOCHS += list(range(args.epochs, args.epochs+29))
                      args.prune_epoch += 29
                      args.epochs += 29

                      args.lr /= 10.
                      # Only group 0 (the non-adjx parameters) gets the new LR; the
                      # adjx group keeps lr_adj.
                      for ip, g in enumerate(optimizer.param_groups):
                          if ip==0: # first parameter group (non adjx)
                              g['lr'] = args.lr
                      logging.debug(optimizer)
          print()

  #######################################################
  # End of task j: freeze (requires_grad=False) the adjx and BN parameters of every
  # task trained so far.
  for ix in range(j+1):
      net = turn_off_adjx(net, ix, bn_off=True)

  print("--------------------------------")
  print("Test acc for all tasks:")
  if logging.getLogger().getEffectiveLevel() <= logging.INFO:
      test_acc, test_loss = all_task_eval(net, task=j, loss_metric=loss_metric, args=args)
      # logging formats lazily with %: values must fill %-placeholders; bare extra
      # arguments make the handler hit a formatting error when the record is emitted.
      logging.info("Test acc, Test loss: %s %s", test_acc, test_loss)
  test_acc, test_loss = all_task_eval(net, task=j, loss_metric=loss_metric, args=args, round=True)
  print("Test acc, Test loss: (Binary Mapping)", test_acc, test_loss)
  print("--------------------------------")
  print()
  # Freeze again right before saving. NOTE(review): presumably guards against the
  # evaluation above changing requires_grad -- confirm whether it is still needed.
  for ix in range(j+1):
      net = turn_off_adjx(net, ix, bn_off=True)
  logging.info("Saving model...")
  torch.save(net, '{}{}{}.pt'.format(args.name, args.savename, j))
  if args.turn_off:
      # Stop after a single task. sys.exit raises SystemExit: a script still exits
      # with status 0, while IPython/Jupyter intercepts it and keeps the kernel (and
      # its state) alive. The builtin exit() is meant for interactive use; in a
      # Jupyter kernel exit(0) shuts the kernel down.
      sys.exit(0)


1
Task: 1 - Epoch: 0




Training Accuracy: 80.735 Loss: 0.004758233503003915
Testing(Binary Mapping) Accuracy: 88.3 Loss: 0.003044848125427961

Task: 1 - Epoch: 1
Task: 1 - Epoch: 2
Task: 1 - Epoch: 3
Task: 1 - Epoch: 4
Task: 1 - Epoch: 5
Task: 1 - Epoch: 6
Task: 1 - Epoch: 7
Task: 1 - Epoch: 8
Task: 1 - Epoch: 9
Task: 1 - Epoch: 10
Training Accuracy: 87.6 Loss: 0.0030093983272711437
Testing(Binary Mapping) Accuracy: 92.84 Loss: 0.0017824942228384317

Task: 1 - Epoch: 11
Task: 1 - Epoch: 12
Task: 1 - Epoch: 13
Task: 1 - Epoch: 14
Task: 1 - Epoch: 15
Task: 1 - Epoch: 16
Task: 1 - Epoch: 17
Task: 1 - Epoch: 18
Task: 1 - Epoch: 19
Task: 1 - Epoch: 20
Training Accuracy: 88.78666666666666 Loss: 0.0027609904987116653
Testing(Binary Mapping) Accuracy: 93.75 Loss: 0.0015834198851138353

Task: 1 - Epoch: 21
Training Accuracy: 89.045 Loss: 0.0027077896813551584
Testing(Binary Mapping) Accuracy: 93.63 Loss: 0.0015741609083488584

Task: 1 - Epoch: 22
Training Accuracy: 89.24 Loss: 0.002647690129528443
Testing(Binary Mapp