In [None]:
!rm -r 'IncrementalLeraningMLDL'
!git clone "https://github.com/wAnto97/IncrementalLeraningMLDL"
from IncrementalLeraningMLDL.src.CIFAR100_dataset import MyCIFAR100
from IncrementalLeraningMLDL.src.Utils import Utils
from IncrementalLeraningMLDL.src.MyNet import MyNet
from IncrementalLeraningMLDL.src.Loss import Loss

import numpy as np
import sys
import copy
from torch.backends import cudnn
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from torch.utils.data import  DataLoader

# Mount Google Drive at /content/gdrive. The training cell later copies its JSON
# result files to './gdrive/My Drive/', which resolves to this mount point when the
# working directory is /content (assumed to be the Colab default -- confirm).
from google.colab import drive
drive.mount('/content/gdrive')

Cloning into 'IncrementalLeraningMLDL'...
remote: Enumerating objects: 12, done.[K
remote: Counting objects:   8% (1/12)[Kremote: Counting objects:  16% (2/12)[Kremote: Counting objects:  25% (3/12)[Kremote: Counting objects:  33% (4/12)[Kremote: Counting objects:  41% (5/12)[Kremote: Counting objects:  50% (6/12)[Kremote: Counting objects:  58% (7/12)[Kremote: Counting objects:  66% (8/12)[Kremote: Counting objects:  75% (9/12)[Kremote: Counting objects:  83% (10/12)[Kremote: Counting objects:  91% (11/12)[Kremote: Counting objects: 100% (12/12)[Kremote: Counting objects: 100% (12/12), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 898 (delta 1), reused 7 (delta 1), pack-reused 886
Receiving objects: 100% (898/898), 10.61 MiB | 21.82 MiB/s, done.
Resolving deltas: 100% (584/584), done.
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent

**Loading data**

In [None]:
# Per-channel mean / std used to normalise the image tensors in both pipelines.
# NOTE(review): these look like the commonly quoted CIFAR-10 statistics -- confirm
# they are the intended values for CIFAR-100.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip for augmentation, then
# PIL image -> torch.Tensor -> normalised tensor.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Train / test splits, both stored under /content and configured with 10 groups.
training_set = MyCIFAR100(
    '/content', train=True, n_groups=10, transform=train_transform, download=True
)
test_set = MyCIFAR100(
    '/content', train=False, n_groups=10, transform=eval_transform, download=True
)

Files already downloaded and verified
Files already downloaded and verified


**Hyperparameters**

In [None]:
# Device on which the network and the batches are placed: 'cuda' or 'cpu'.
DEVICE = 'cuda'

# Mini-batch size. Higher batch sizes allow for larger learning rates; an empirical
# heuristic suggests that, when changing the batch size, the learning rate should
# change by the same factor to obtain comparable results.
BATCH_SIZE = 128

# Optimizer settings, forwarded to MyNet.prepare_training.
LR = 2                # Initial learning rate
MOMENTUM = 0.9        # Momentum for SGD; keep this at 0.9 when using SGD
WEIGHT_DECAY = 1e-5   # Regularization strength; the default is usually fine

# Training schedule.
NUM_EPOCHS = 70       # Number of passes over the training data for each group
STEP_SIZE = [49, 63]  # Epoch milestones for the step-down learning-rate policy
GAMMA = 0.2           # Multiplicative factor applied to the learning rate at a step-down

# Logging interval (not referenced by the cells visible here).
LOG_FREQUENCY = 10

# Incremental-learning layout: 10 groups of 10 classes each.
CLASSES_PER_GROUP = 10
NUM_GROUPS = 10

**Utility functions**

In [None]:
def validation(val_dataloader,net,lwf,conf_matrix=False):
    """Evaluate `net` on every batch of `val_dataloader`.

    The network is switched to evaluation mode and the whole pass runs under
    `torch.no_grad()`, so no autograd graph is built while evaluating.

    Besides its arguments, this function reads the notebook-level names
    `DEVICE`, `myLoss`, `step`, `utils` and `CLASSES_PER_GROUP`; they must be
    defined before it is called.

    Args:
        val_dataloader: iterable yielding `(images, labels, _)` batches; its
            `.dataset` length is used as the accuracy denominator.
        net: the model to evaluate (called as `net(images)`).
        lwf: accepted for call compatibility but not used -- the loss is
            computed through the notebook-level `myLoss` object.
        conf_matrix: when truthy, return a confusion matrix instead of the
            mean loss.

    Returns:
        `(accuracy, confusion_matrix)` if `conf_matrix` is truthy, otherwise
        `(accuracy, mean_loss)` where `mean_loss` is the mean of the per-batch
        loss values.
    """
    net.train(False)  # evaluation mode
    running_corrects = 0
    y_pred = []
    all_labels = []
    batch_losses = []

    # Evaluation only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        for images, labels, _ in val_dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            # Forward pass and predicted class (argmax over dimension 1)
            outputs = net(images)
            _, preds = torch.max(outputs, 1)

            # current_step = -1 is the value this notebook passes when the loss
            # is computed on validation/test data; no old outputs are supplied.
            loss, _, _ = myLoss.icarl_loss([],outputs,labels,step,-1,utils,CLASSES_PER_GROUP)
            batch_losses.append(loss.item())

            # Update the number of correct predictions and keep the raw
            # predictions / labels as plain Python ints.
            running_corrects += torch.sum(preds == labels).item()
            y_pred += preds.tolist()
            all_labels += labels.tolist()

    # Accuracy over the whole dataset behind the dataloader
    accuracy = running_corrects / float(len(val_dataloader.dataset))

    if conf_matrix:
        # NOTE(review): sklearn's signature is confusion_matrix(y_true, y_pred);
        # the argument order here is (predictions, labels), so rows index the
        # predicted class. Kept as-is because downstream consumers of the saved
        # matrix may rely on this orientation -- confirm.
        return accuracy,confusion_matrix(y_pred,np.array(all_labels))

    return accuracy,np.array(batch_losses).mean()

**Main**

In [None]:
# Core objects used by the training loop: the network wrapper (initially sized for
# one group of classes), the helper utilities and the loss provider.
myNet = MyNet(n_classes=CLASSES_PER_GROUP)
utils = Utils()
myLoss = Loss()

# Scheduler type handed to MyNet.prepare_training; the training loop branches on
# 'multistep' and 'plateau'.
typeScheduler = 'plateau'

# Dataloaders for the first group of 10 classes (group index 1)
(train_dataloader,
 val_dataloader,
 test_dataloader) = utils.create_dataloaders(training_set, test_set, 1, BATCH_SIZE)

# Per-group bookkeeping, filled in as each group finishes training
best_train_accuracies, best_val_accuracies, best_test_accuracies = [], [], []
losses_train_all = []

# Outputs of the previous network; stays empty while training the first group
old_outputs = []

for i in range(NUM_GROUPS):
    best_val_accuracy = -1

    step=i+1
    print("STARTING LwF TRAINING WITH GROUP:\t",step)  
    
    if step > 1:
      n_old_classes = CLASSES_PER_GROUP*(step-1)
      train_dataloader,val_dataloader,test_dataloader = utils.create_dataloaders(training_set,test_set,step,BATCH_SIZE)
      myNet.update_network(best_net,CLASSES_PER_GROUP + n_old_classes,myNet.init_weights)
      
    optimizer,scheduler = myNet.prepare_training(LR,MOMENTUM,WEIGHT_DECAY,STEP_SIZE,GAMMA,typeScheduler=typeScheduler)

    losses_train = []
    losses_val = []
    val_accuracies = []
    train_accuracies = []
    classification_losses = []
    distillation_losses = []

    myNet.net = myNet.net.to(DEVICE)
    cudnn.benchmark 

    for epoch in range(NUM_EPOCHS):
        running_correct_train = 0
        if typeScheduler == 'multistep':
          print('Starting epoch {}/{}, LR = {}'.format(epoch+1, NUM_EPOCHS, scheduler.get_last_lr()))
        elif typeScheduler == 'plateau':
          print('Starting epoch {}/{}, LR = {}'.format(epoch+1, NUM_EPOCHS, optimizer.param_groups[0]['lr']))

        myNet.net.train() # Set Network to train mode
        current_step = 0
        losses_tmp = []
        for images, labels, _ in train_dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            #Set all gradients to zero
            optimizer.zero_grad() 

            #Computing output and creating the acyclic graph for updating the gradients
            outputs = myNet.net(images)

            #Computing predictions
            _, preds = torch.max(outputs.data, 1)
            
            if(step > 1):
                old_outputs = myNet.get_old_outputs(images,labels)

            loss,clf_loss,dist_loss = myLoss.icarl_loss(old_outputs,outputs,labels,step,current_step,utils,CLASSES_PER_GROUP)
            classification_losses.append(clf_loss.item())
            distillation_losses.append(dist_loss.item())
            losses_tmp.append(loss.item())

            #Calculate correct predictions
            running_correct_train += torch.sum(preds == labels.data).data.item()

            #Accumulate gradients
            loss.backward()

            # Update weights based on accumulated gradients  
            optimizer.step() 

            current_step += 1

        #Calculate training accuracy
        train_accuracy = running_correct_train/len(train_dataloader.dataset)

        #Validate the model
        val_accuracy,val_loss = validation(val_dataloader,myNet.net,myLoss)

        print("Accuracy on the training :\t",train_accuracy)
        print("Accuracy on the validation :\t",val_accuracy)

        #Save the net which minimizes the accuracy on the validation
        if val_accuracy > best_val_accuracy:
            best_net = copy.deepcopy(myNet.net)
            best_train_accuracy = train_accuracy
            best_val_accuracy = val_accuracy
        
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        losses_train.append(np.array(losses_tmp).mean())

        # Step the scheduler
        if typeScheduler == 'multistep':
            scheduler.step()
        #Reduce the learning rate when the validation loss reaches a plateau
        elif typeScheduler == 'plateau':
            scheduler.step(val_loss)

    # Save accuracies and losses            
    print("Best accuracy on the training :\t",best_train_accuracy)
    print("Best accuracy on the validation :\t",best_val_accuracy)
    losses_train_all.append(losses_train)
    best_train_accuracies.append(best_train_accuracy)
    best_val_accuracies.append(best_val_accuracy)

    #Test 
    test_accuracy,test_matrix = validation(test_dataloader,best_net,myLoss,conf_matrix=True)
    print("Accuracy on the test :\t",test_accuracy)
    best_test_accuracies.append(test_accuracy)
    
    utils.writeOnFileMetrics('LwFMetrics_icarlLoss_plateau.json', step, [best_train_accuracy,best_val_accuracy,test_accuracy,test_matrix.tolist()])
    utils.writeOnFileLosses('LwFLosses_icarlLoss_plateau.json', step, [classification_losses,distillation_losses])
    !cp  './LwFMetrics_icarlLoss_plateau.json' './gdrive/My Drive/LwFMetrics_icarlLoss_plateau.json'
    !cp  './LwFLosses_icarlLoss_plateau.json' './gdrive/My Drive/LwFLosses_icarlLoss_plateau.json'

STARTING LwF TRAINING WITH GROUP:	 1
Starting epoch 1/70, LR = 2
Accuracy on the training :	 0.11577777777777777
Accuracy on the validation :	 0.132
Starting epoch 2/70, LR = 2
Accuracy on the training :	 0.14422222222222222
Accuracy on the validation :	 0.138
Starting epoch 3/70, LR = 2
Accuracy on the training :	 0.15311111111111111
Accuracy on the validation :	 0.166
Starting epoch 4/70, LR = 2
Accuracy on the training :	 0.18222222222222223
Accuracy on the validation :	 0.21
Starting epoch 5/70, LR = 2
Accuracy on the training :	 0.206
Accuracy on the validation :	 0.176
Starting epoch 6/70, LR = 2
Accuracy on the training :	 0.22777777777777777
Accuracy on the validation :	 0.236
Starting epoch 7/70, LR = 2
Accuracy on the training :	 0.2348888888888889
Accuracy on the validation :	 0.254
Starting epoch 8/70, LR = 2
Accuracy on the training :	 0.26066666666666666
Accuracy on the validation :	 0.228
Starting epoch 9/70, LR = 2
Accuracy on the training :	 0.2737777777777778
Accuracy