# Libraries

In [None]:
# google drive access
# Mount the user's Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# IPython magic (not plain Python): switch the working directory to the
# notebook folder so that relative paths used later ('./data', './models',
# 'runs/') resolve inside Drive.
%cd drive/MyDrive/Colab Notebooks

Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks


In [None]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torchvision
from sklearn.metrics import accuracy_score

# Model

In [None]:
class CCM(nn.Module):
    """
    Channel-wise Calibration Module (CCM).

    Scales the input by a per-channel sigmoid gate computed from its
    globally pooled features:
        y = x * Sigmoid(BN(Group_Conv(GAP(x))))
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        # GAP: collapse every feature map to a single value per channel
        self.adapool = nn.AdaptiveAvgPool2d(1)
        # 1x1 grouped convolution with one group per input channel
        self.conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=in_channel,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channel)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        # gate has a 1x1 spatial size and broadcasts over H and W
        gate = self.sig(self.bn(self.conv(self.adapool(x))))
        return x * gate


class SCM(nn.Module):
    """
    Spatial Calibration Module (SCM).

    Adds a 3x3 grouped convolution of the input (one group per channel)
    back onto the input itself:
        y = x + Group_Conv(x)
    """
    def __init__(self, channel):
        super().__init__()
        # stride 1 and padding 1 keep the spatial size, so the sum is valid
        self.conv = nn.Conv2d(
            channel,
            channel,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=channel,
            bias=False,
        )

    def forward(self, x):
        return x + self.conv(x)


class Shortcut(nn.Module):
    """
    Projection skip connection of the ResNet, extended with
    calibration modules: 1x1 conv -> SCM -> BN -> CCM.
    """
    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        # 1x1 projection that changes the channel count and/or resolution
        self.conv = nn.Conv2d(
            in_channel, out_channel, kernel_size=1, stride=stride, bias=False
        )
        self.scm = SCM(out_channel)
        self.bn = nn.BatchNorm2d(out_channel)
        self.ccm = CCM(out_channel, out_channel)

    def forward(self, x):
        projected = self.scm(self.conv(x))
        return self.ccm(self.bn(projected))


class BasicBlock(nn.Module):
    """
    Convolutional block of ResNet with a skip connection and
    calibration modules.

    Each of the two 3x3 convolutions is followed by SCM -> BN -> CCM.
    The block input is added to the result before the final ReLU; when
    the stride or the channel count changes, the input first goes
    through a ``Shortcut`` projection.
    """
    # NOTE: the docstring has to be the first statement of the class body;
    # placed after this attribute it would be a dead string expression and
    # ``BasicBlock.__doc__`` would be None.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        # SCM and CCM for conv1
        self.scm1 = SCM(planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.ccm1 = CCM(planes, planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        # SCM and CCM for conv2
        self.scm2 = SCM(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.ccm2 = CCM(planes, planes)

        # a projection is only needed when the identity cannot be added as-is
        self.shortcut = None
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = Shortcut(in_planes, self.expansion*planes, stride)

    def forward(self, x):
        # conv1 forward pass
        out = self.conv1(x)
        out = self.scm1(out)
        out = self.bn1(out)
        out = self.ccm1(out)
        out = F.relu(out)
        # conv2 forward pass
        out = self.conv2(out)
        out = self.scm2(out)
        out = self.bn2(out)
        out = self.ccm2(out)
        # residual connection (projected when shapes differ)
        if self.shortcut is not None:
            out += self.shortcut(x)
        else:
            out += x
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """
    ResNet with calibration modules (SCM / CCM) after the stem
    convolution, followed by four residual stages and a linear
    classifier head.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # running channel count, advanced by _make_layer
        self.in_planes = 64
        # stem: conv -> SCM -> BN -> CCM
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.scm1 = SCM(64)
        self.bn1 = nn.BatchNorm2d(64)
        self.ccm1 = CCM(64, 64)
        # residual stages; every stage after the first uses stride 2
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # classifier head
        self.classifier = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one uses `stride`."""
        stage = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        # stem
        out = F.relu(self.ccm1(self.bn1(self.scm1(self.conv1(x)))))
        # residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # classifier head: 4x4 average pooling, then flatten. The linear
        # layer expects 512*expansion features, i.e. a 1x1 pooled map
        # (which is what 32x32 inputs produce).
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.classifier(out)


def get_resnet18(num_classes=10):
    """Build a ResNet made of four stages with two BasicBlocks each."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)

# Utils

In [None]:
def freeze_noncalibration_layers(model):
    """
    Freeze the base-model convolutions of `model`.

    A Conv2d is frozen when it is a direct child of the model, of one of
    its Sequential stages, of a BasicBlock inside such a stage, or of a
    Shortcut inside such a BasicBlock. Convolutions nested inside other
    module types (e.g. SCM / CCM) are never reached and stay trainable.
    """
    # container type that is opened at each nesting level (exact type match)
    descend_into = (nn.Sequential, BasicBlock, Shortcut)
    # indentation passed on to froze_this_layer, one entry per level
    tabs = ('', '\t ', '\t\t ', '\t\t\t ')

    parents = [model]
    for depth, tab in enumerate(tabs):
        opened = []
        for parent in parents:
            for module in parent.children():
                if type(module) is nn.Conv2d:
                    # Base model convs are frozen
                    froze_this_layer(module, tab=tab)
                elif depth < len(descend_into) and type(module) is descend_into[depth]:
                    opened.append(module)
        parents = opened

def froze_this_layer(layer, tab=''):
    """
    Disable gradient computation for every parameter of `layer`.

    `tab` is only an indentation prefix for optional debug output and
    has no effect on the freezing itself.
    """
    for weight in layer.parameters():
        weight.requires_grad_(False)


def accuracy(y, y_hat):
    """
    Classification accuracy of class scores against true labels.

    Args:
        y: array of true class labels, one per sample.
        y_hat: 2-D array of class scores; the predicted class of a
            sample is the argmax along axis 1.

    Returns:
        The value returned by sklearn's ``accuracy_score`` for the
        labels and the argmax predictions.
    """
    y_hat = np.argmax(y_hat, axis=1)
    # Flatten instead of squeeze: squeeze turns a single-sample batch into
    # a 0-d array, which accuracy_score rejects.
    y = np.ravel(y)
    acc = accuracy_score(y, y_hat)
    return acc

# Dataset

In [None]:
def get_dataloader(task_no=0, batch_size=128, subset='train', val_ratio=0.1):
    """
    Build the DataLoader of one task of the split-CIFAR experiment.

    Task 0 is the complete CIFAR10 dataset. Every other task is a slice
    of ten CIFAR100 classes (class group ``task_no + 2``) wrapped in
    ``CustomCIFAR100``.

    Args:
        task_no: task index; 0 selects CIFAR10, anything else CIFAR100.
        batch_size: mini-batch size of the 'train' subset. The 'val' and
            'test' subsets are always served as one single batch.
        subset: one of 'train', 'val' or 'test'.
        val_ratio: fraction of the official training set that is held
            out for validation (random split with a fixed seed of 42).

    Returns:
        A ``DataLoader`` for the requested task and subset.

    Raises:
        ValueError: if `subset` is not 'train', 'val' or 'test'.
    """
    # Fail early: previously an unknown subset only surfaced as an
    # UnboundLocalError after the dataset was downloaded and split.
    if subset not in ('train', 'val', 'test'):
        raise ValueError("subset must be 'train', 'val' or 'test', got {!r}".format(subset))

    transform = torchvision.transforms.Compose(
                    [torchvision.transforms.ToTensor(),
                     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    is_test = subset == 'test'

    # initial task is CIFAR10
    if task_no == 0:
        dataset = torchvision.datasets.CIFAR10(root='./data', train=not is_test, download=True, transform=transform)
        if is_test:
            print('Task-{} -->  CIFAR10 ({}) loaded! Num. Samples: {}'.format(task_no, subset, len(dataset)))
            return DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=2)

        # split train_validation
        val_len = int(len(dataset) * val_ratio)
        train_len = int(len(dataset) - val_len)
        train_data, val_data = random_split(dataset, [train_len, val_len], generator=torch.Generator().manual_seed(42))

        if subset == 'train':
            dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
            print('Task-{} -->  CIFAR10 ({}) loaded! Num. Samples: {}'.format(task_no, subset, len(train_data)))
        else:
            dataloader = DataLoader(val_data, batch_size=len(val_data), shuffle=False, num_workers=2)
            print('Task-{} -->  CIFAR10 ({}) loaded! Num. Samples: {}'.format(task_no, subset, len(val_data)))
        return dataloader

    # subsequent tasks are CIFAR100
    # Task k uses the (k + 2)-th group of ten CIFAR100 classes. A separate
    # name keeps `task_no` intact for the log messages.
    class_group = task_no + 2
    # no transform here: CustomCIFAR100 applies it per sample
    dataset = torchvision.datasets.CIFAR100(root='./data', train=not is_test, download=True)
    classes = list(dataset.class_to_idx.keys())
    task_classes = classes[class_group*10:10+class_group*10]
    task_class_labels = [int(dataset.class_to_idx[class_]) for class_ in task_classes]

    if is_test:
        task_data = CustomCIFAR100(dataset, task_classes, task_class_labels, task_no=class_group, transforms=transform)
        print('Task-{} -->  CIFAR100 ({}) loaded! Num. Samples: {}'.format(task_no, subset, len(task_data)))
        return DataLoader(task_data, batch_size=len(task_data), shuffle=False, num_workers=2)

    # split train_validation (on the full dataset, before class filtering)
    val_len = int(len(dataset) * val_ratio)
    train_len = int(len(dataset) - val_len)
    train_data, val_data = random_split(dataset, [train_len, val_len], generator=torch.Generator().manual_seed(42))

    if subset == 'train':
        task_data = CustomCIFAR100(train_data, task_classes, task_class_labels, task_no=class_group, transforms=transform)
        dataloader = DataLoader(task_data, batch_size=batch_size, shuffle=True, num_workers=2)
    else:
        task_data = CustomCIFAR100(val_data, task_classes, task_class_labels, task_no=class_group, transforms=transform)
        dataloader = DataLoader(task_data, batch_size=len(task_data), shuffle=False, num_workers=2)
    print('Task-{} -->  CIFAR100 ({}) loaded! Num. Samples: {}'.format(task_no, subset, len(task_data)))

    return dataloader


class CustomCIFAR100(Dataset):
    """
    View of `dataset` restricted to the samples of selected classes.

    Only samples whose label is contained in `class_idxs` are kept, and
    every kept label is shifted down by ``10 * task_no``.

    Args:
        dataset: indexable collection of ``(image, label)`` pairs that
            supports ``len()``.
        class_names: names of the selected classes (stored, not used
            for filtering).
        class_idxs: labels of the selected classes.
        task_no: class-group index used for the label shift.
        transforms: optional callable applied to every returned image.
    """
    def __init__(self, dataset, class_names, class_idxs, task_no=0, transforms=None):
        super(CustomCIFAR100, self).__init__()
        self.dataset = dataset
        self.task_no = task_no
        self.class_names = class_names
        self.class_idxs = class_idxs
        self.transforms = transforms
        # Single pass over the data: each sample is read once, both for the
        # class filter and for its shifted label (reading a sample can be
        # expensive, so the former second pass is avoided).
        self.idxs = []
        self.labels = []
        label_offset = 10*self.task_no
        for idx in range(len(dataset)):
            label = self.dataset[idx][1]
            if label in self.class_idxs:
                self.idxs.append(idx)
                self.labels.append(label - label_offset)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # map the position inside this view back to the wrapped dataset
        img, _ = self.dataset[self.idxs[idx]]
        y = self.labels[idx]

        # convert labels to torch tensor too
        if not isinstance(y, np.ndarray):
            y = np.array(y)
            y = torch.from_numpy(y)

        # apply input image transformations
        if self.transforms:
            img = self.transforms(img)
        return img, y

In [None]:
"""
train_loader0 = get_dataloader(task_no=0, batch_size=128, subset='train')
train_loader4 = get_dataloader(task_no=4, batch_size=64, subset='train')

val_loader0 = get_dataloader(task_no=0, subset='val')
val_loader4 = get_dataloader(task_no=4, subset='val')

test_loader0 = get_dataloader(task_no=0, subset='test')
test_loader4 = get_dataloader(task_no=4, subset='test')
"""

"\ntrain_loader0 = get_dataloader(task_no=0, batch_size=128, subset='train')\ntrain_loader4 = get_dataloader(task_no=4, batch_size=64, subset='train')\n\nval_loader0 = get_dataloader(task_no=0, subset='val')\nval_loader4 = get_dataloader(task_no=4, subset='val')\n\ntest_loader0 = get_dataloader(task_no=0, subset='test')\ntest_loader4 = get_dataloader(task_no=4, subset='test')\n"

# Main

In [None]:
# Select the compute device once; later cells read the global `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
  print("Cuda (GPU support) is available and enabled!")
else:
  print("Cuda (GPU support) is not available")

Cuda (GPU support) is available and enabled!


In [None]:
def _validate(model, loader, criterion):
    """Return (loss, accuracy) of `model` computed over all batches of `loader`."""
    outputs, targets = [], []
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            outputs.append(model(x.to(device)))
            targets.append(y.to(device))
        # Concatenate so the metrics cover the whole loader, not only its
        # last batch (identical when the loader yields one batch).
        yhat = torch.cat(outputs)
        y = torch.cat(targets)
        val_loss = criterion(yhat, y).item()
    val_acc = accuracy(y.to('cpu').numpy(), yhat.to('cpu').numpy())
    return val_loss, val_acc


def train(verbose=True):
    """
    Train a ResNet18 sequentially on six tasks (task 0 to task 5).

    Task 0 trains the whole network. Before every later task the base
    convolutions are frozen with ``freeze_noncalibration_layers`` and the
    classifier is replaced by a fresh 10-way linear head. Losses and
    validation accuracy are logged to TensorBoard, and the whole model
    is saved after every task.

    Args:
        verbose: if True, print the training/validation scores after
            every epoch.
    """
    import os  # local import: only needed to prepare the checkpoint folder

    # create ResNet18 model
    model = ResNet(BasicBlock, [2,2,2,2], num_classes=10)
    model = model.to(device)

    # torch.save does not create missing folders; without this the first
    # checkpoint would fail after a complete task of training
    model_path = './models/splitCIFAR_experiment/'
    os.makedirs(model_path, exist_ok=True)

    # tensorboard logger
    writer = SummaryWriter(log_dir='runs/splitCIFAR_experiment')
    try:
        for task_no in range(6):
            task_tag = 'task_' + str(task_no)

            # training parameters
            batch_size = 64
            lr = 1e-2
            if task_no == 0:
                epochs, milestones = 15, [10]
            else:
                epochs, milestones = 50, [15, 30]

            # freeze layers other than CCM and SCM,
            # add new classifier head for next task
            if task_no > 0:
                print('************* Re-calibrating model parameters *************')
                model = model.to('cpu')
                freeze_noncalibration_layers(model)
                model.classifier = nn.Linear(512*BasicBlock.expansion, 10)
                model = model.to(device)

            # check layer freezing
            num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print('************* Task: {}, Number of Trainable Parameters: {} *************'.format(task_no, num_trainable_params))

            # get dataset for task no
            task_train_loader = get_dataloader(task_no=task_no, subset='train', batch_size=batch_size)
            task_val_loader = get_dataloader(task_no=task_no, subset='val')

            # loss and optimizer are identical for each task,
            # only the lr-scheduler milestones differ
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.SGD(model.parameters(), lr, momentum=0.9)
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

            # training iterations
            tr_loss_hist, val_loss_hist, val_acc_hist = [], [], []
            num_iter = 0
            for e in range(epochs):
                # put model to training mode after each eval
                model.train()
                # training for an epoch
                tqdm_desc = 'Epoch ' + str(e+1) + '/' + str(epochs)
                for _, data in tqdm(enumerate(task_train_loader), ascii=True, desc=tqdm_desc):
                    # get the data
                    x, y = data
                    x, y = x.to(device), y.to(device)

                    # clear gradients
                    optimizer.zero_grad()

                    # calculate loss
                    yhat = model(x)
                    loss = criterion(yhat, y)

                    # backprop
                    loss.backward()

                    # update parameters
                    optimizer.step()

                    # store training loss
                    tr_loss_hist.append(loss.item())
                    # save training loss per minibatch iterations
                    writer.add_scalar('tr_loss/'+task_tag, loss.item(), num_iter)
                    num_iter += 1

                # average training loss of the last (up to) 5 iterations;
                # a true mean stays correct with fewer than 5 iterations
                tr_loss = np.mean(tr_loss_hist[-5:]) if tr_loss_hist else float('nan')
                # update learning-rate
                scheduler.step()

                # validation after each epoch
                val_loss, val_acc = _validate(model, task_val_loader, criterion)

                # save validation scores
                val_loss_hist.append(val_loss)
                val_acc_hist.append(val_acc)
                # save to tensorboard
                writer.add_scalar('val_loss/'+task_tag, val_loss, e)
                writer.add_scalar('val_acc/'+task_tag, val_acc, e)
                # display training info
                if verbose:
                    print('tr_loss:{} || val_loss:{} || val_acc:{}'.format(tr_loss, val_loss, val_acc))

            # save the model for current task
            print('************* Saving model for Task-{} *************'.format(task_no))
            model_filename = 'splitcifar_model_task_' + str(task_no) + '.pt'
            torch.save(model, model_path + model_filename)
            print('************* End of Task-{} training *************'.format(task_no))
    finally:
        # flush pending events and release the log file, even on errors
        writer.close()

In [None]:
# Run the sequential training over all tasks with per-epoch score printing
# (verbose defaults to True).
train()

************* Task: 0, Number of Trainable Parameters: 11231562 *************
Files already downloaded and verified
Task-0 -->  CIFAR10 (train) loaded! Num. Samples: 45000
Files already downloaded and verified
Task-0 -->  CIFAR10 (val) loaded! Num. Samples: 5000


Epoch 1/15: 704it [01:17,  9.07it/s]


tr_loss:1.3788735151290894 || val_loss:1.2268131971359253 || val_acc:0.5414


Epoch 2/15: 704it [01:19,  8.83it/s]


tr_loss:0.9486603379249573 || val_loss:0.9298802614212036 || val_acc:0.671


Epoch 3/15: 704it [01:21,  8.66it/s]


tr_loss:0.7408694982528686 || val_loss:0.7602753639221191 || val_acc:0.735


Epoch 4/15: 704it [01:22,  8.57it/s]


tr_loss:0.5707391738891602 || val_loss:0.6715573072433472 || val_acc:0.7622


Epoch 5/15: 704it [01:22,  8.57it/s]


tr_loss:0.6098137319087982 || val_loss:1.0911964178085327 || val_acc:0.7634


Epoch 6/15: 704it [01:22,  8.57it/s]


tr_loss:0.7191021680831909 || val_loss:0.5962232947349548 || val_acc:0.8084


Epoch 7/15: 704it [01:22,  8.56it/s]


tr_loss:0.36608969420194626 || val_loss:0.5849146246910095 || val_acc:0.8158


Epoch 8/15: 704it [01:22,  8.58it/s]


tr_loss:0.4392354771494865 || val_loss:0.7286567687988281 || val_acc:0.8018


Epoch 9/15: 704it [01:22,  8.58it/s]


tr_loss:0.3092836320400238 || val_loss:0.6813505291938782 || val_acc:0.8176


Epoch 10/15: 704it [01:22,  8.58it/s]


tr_loss:0.20596864223480224 || val_loss:0.7188757061958313 || val_acc:0.8264


Epoch 11/15: 704it [01:22,  8.57it/s]


tr_loss:0.24421435333788394 || val_loss:0.6011775732040405 || val_acc:0.8416


Epoch 12/15: 704it [01:22,  8.57it/s]


tr_loss:0.4499983871355653 || val_loss:0.6028205752372742 || val_acc:0.8426


Epoch 13/15: 704it [01:22,  8.57it/s]


tr_loss:0.16515066046267748 || val_loss:0.6031562089920044 || val_acc:0.8476


Epoch 14/15: 704it [01:22,  8.58it/s]


tr_loss:0.0694453286472708 || val_loss:0.6023108959197998 || val_acc:0.8462


Epoch 15/15: 704it [01:22,  8.56it/s]


tr_loss:0.00943884951993823 || val_loss:0.6078270077705383 || val_acc:0.8462
************* Saving model for Task-0 *************
************* End of Task-0 training *************
************* Re-calibrating model parameters *************
************* Task: 1, Number of Trainable Parameters: 72330 *************
Files already downloaded and verified
Task-1 -->  CIFAR100 (train) loaded! Num. Samples: 4510
Files already downloaded and verified
Task-1 -->  CIFAR100 (val) loaded! Num. Samples: 490


Epoch 1/50: 71it [00:05, 12.91it/s]


tr_loss:1.343268871307373 || val_loss:1.2160282135009766 || val_acc:0.5816326530612245


Epoch 2/50: 71it [00:05, 13.04it/s]


tr_loss:1.106087875366211 || val_loss:1.0680073499679565 || val_acc:0.6448979591836734


Epoch 3/50: 71it [00:05, 12.91it/s]


tr_loss:1.0359187722206116 || val_loss:0.9924860596656799 || val_acc:0.6428571428571429


Epoch 4/50: 71it [00:05, 12.93it/s]


tr_loss:1.0942257404327393 || val_loss:0.9648134708404541 || val_acc:0.6612244897959184


Epoch 5/50: 71it [00:05, 12.83it/s]


tr_loss:0.912748122215271 || val_loss:0.9121530055999756 || val_acc:0.6836734693877551


Epoch 6/50: 71it [00:05, 12.92it/s]


tr_loss:0.9441177845001221 || val_loss:0.9110133051872253 || val_acc:0.6918367346938775


Epoch 7/50: 71it [00:05, 12.98it/s]


tr_loss:0.9396047234535218 || val_loss:0.8752606511116028 || val_acc:0.7183673469387755


Epoch 8/50: 71it [00:05, 13.00it/s]


tr_loss:0.7852700114250183 || val_loss:0.8461692333221436 || val_acc:0.7142857142857143


Epoch 9/50: 71it [00:05, 12.99it/s]


tr_loss:0.8580502867698669 || val_loss:0.8645159602165222 || val_acc:0.7122448979591837


Epoch 10/50: 71it [00:05, 13.02it/s]


tr_loss:0.9600228667259216 || val_loss:0.8479751348495483 || val_acc:0.7183673469387755


Epoch 11/50: 71it [00:05, 12.97it/s]


tr_loss:0.6410321950912475 || val_loss:0.8166452050209045 || val_acc:0.7346938775510204


Epoch 12/50: 71it [00:05, 13.05it/s]


tr_loss:0.6344046711921691 || val_loss:0.8137484192848206 || val_acc:0.7163265306122449


Epoch 13/50: 71it [00:05, 13.03it/s]


tr_loss:0.6845620632171631 || val_loss:0.8023894429206848 || val_acc:0.7224489795918367


Epoch 14/50: 71it [00:05, 12.99it/s]


tr_loss:0.7124549150466919 || val_loss:0.7961039543151855 || val_acc:0.7346938775510204


Epoch 15/50: 71it [00:05, 12.92it/s]


tr_loss:0.6137267410755157 || val_loss:0.777801513671875 || val_acc:0.7346938775510204


Epoch 16/50: 71it [00:05, 12.97it/s]


tr_loss:0.4504363477230072 || val_loss:0.7803478240966797 || val_acc:0.7408163265306122


Epoch 17/50: 71it [00:05, 12.96it/s]


tr_loss:0.5804876923561096 || val_loss:0.7788235545158386 || val_acc:0.7448979591836735


Epoch 18/50: 71it [00:05, 12.96it/s]


tr_loss:0.6229505598545074 || val_loss:0.771792471408844 || val_acc:0.7408163265306122


Epoch 19/50: 71it [00:05, 12.87it/s]


tr_loss:0.6683576583862305 || val_loss:0.7709687948226929 || val_acc:0.7346938775510204


Epoch 20/50: 71it [00:05, 12.91it/s]


tr_loss:0.5700508058071136 || val_loss:0.7626296281814575 || val_acc:0.7387755102040816


Epoch 21/50: 71it [00:05, 13.00it/s]


tr_loss:0.4932330846786499 || val_loss:0.7755418419837952 || val_acc:0.7387755102040816


Epoch 22/50: 71it [00:05, 13.02it/s]


tr_loss:0.6069836139678955 || val_loss:0.7746232151985168 || val_acc:0.7387755102040816


Epoch 23/50: 71it [00:05, 13.05it/s]


tr_loss:0.5297769188880921 || val_loss:0.7770743370056152 || val_acc:0.7326530612244898


Epoch 24/50: 71it [00:05, 13.06it/s]


tr_loss:0.5755434095859527 || val_loss:0.7695490717887878 || val_acc:0.7428571428571429


Epoch 25/50: 71it [00:05, 13.05it/s]


tr_loss:0.5087244093418122 || val_loss:0.7756391167640686 || val_acc:0.7346938775510204


Epoch 26/50: 71it [00:05, 13.05it/s]


tr_loss:0.5080487608909607 || val_loss:0.7718403935432434 || val_acc:0.7408163265306122


Epoch 27/50: 71it [00:05, 13.06it/s]


tr_loss:0.5399941802024841 || val_loss:0.7667236924171448 || val_acc:0.7346938775510204


Epoch 28/50: 71it [00:05, 12.97it/s]


tr_loss:0.5047073185443878 || val_loss:0.7667927145957947 || val_acc:0.7448979591836735


Epoch 29/50: 71it [00:05, 12.93it/s]


tr_loss:0.5318286180496216 || val_loss:0.7719144821166992 || val_acc:0.7408163265306122


Epoch 30/50: 71it [00:05, 12.96it/s]


tr_loss:0.5400548100471496 || val_loss:0.7613691687583923 || val_acc:0.7428571428571429


Epoch 31/50: 71it [00:05, 12.98it/s]


tr_loss:0.4871855556964874 || val_loss:0.7595134377479553 || val_acc:0.746938775510204


Epoch 32/50: 71it [00:05, 13.01it/s]


tr_loss:0.5136809766292572 || val_loss:0.7716318368911743 || val_acc:0.7448979591836735


Epoch 33/50: 71it [00:05, 13.08it/s]


tr_loss:0.5826808750629425 || val_loss:0.7643830180168152 || val_acc:0.746938775510204


Epoch 34/50: 71it [00:05, 13.08it/s]


tr_loss:0.5544237315654754 || val_loss:0.7666661143302917 || val_acc:0.746938775510204


Epoch 35/50: 71it [00:05, 13.00it/s]


tr_loss:0.5596658527851105 || val_loss:0.7715693712234497 || val_acc:0.7448979591836735


Epoch 36/50: 71it [00:05, 12.95it/s]


tr_loss:0.5228244960308075 || val_loss:0.7585274577140808 || val_acc:0.7448979591836735


Epoch 37/50: 71it [00:05, 13.00it/s]


tr_loss:0.5014949560165405 || val_loss:0.7681795358657837 || val_acc:0.7448979591836735


Epoch 38/50: 71it [00:05, 13.00it/s]


tr_loss:0.588269567489624 || val_loss:0.7694231271743774 || val_acc:0.7387755102040816


Epoch 39/50: 71it [00:05, 12.93it/s]


tr_loss:0.5683673739433288 || val_loss:0.7703243494033813 || val_acc:0.7387755102040816


Epoch 40/50: 71it [00:05, 12.94it/s]


tr_loss:0.5470726609230041 || val_loss:0.767113447189331 || val_acc:0.7346938775510204


Epoch 41/50: 71it [00:05, 12.90it/s]


tr_loss:0.511821734905243 || val_loss:0.7634433507919312 || val_acc:0.746938775510204


Epoch 42/50: 71it [00:05, 12.96it/s]


tr_loss:0.5265818417072297 || val_loss:0.7587276697158813 || val_acc:0.7408163265306122


Epoch 43/50: 71it [00:05, 13.00it/s]


tr_loss:0.5886247515678406 || val_loss:0.7740143537521362 || val_acc:0.746938775510204


Epoch 44/50: 71it [00:05, 13.00it/s]


tr_loss:0.40455214977264403 || val_loss:0.7631489038467407 || val_acc:0.7448979591836735


Epoch 45/50: 71it [00:05, 13.04it/s]


tr_loss:0.5077342450618744 || val_loss:0.7588391304016113 || val_acc:0.7408163265306122


Epoch 46/50: 71it [00:05, 13.05it/s]


tr_loss:0.6860312521457672 || val_loss:0.7660296559333801 || val_acc:0.746938775510204


Epoch 47/50: 71it [00:05, 13.07it/s]


tr_loss:0.48674153685569765 || val_loss:0.7658599019050598 || val_acc:0.7448979591836735


Epoch 48/50: 71it [00:05, 12.99it/s]


tr_loss:0.6300402402877807 || val_loss:0.7726731896400452 || val_acc:0.7489795918367347


Epoch 49/50: 71it [00:05, 13.00it/s]


tr_loss:0.426692795753479 || val_loss:0.7696315050125122 || val_acc:0.7448979591836735


Epoch 50/50: 71it [00:05, 12.98it/s]


tr_loss:0.5033047616481781 || val_loss:0.7634559869766235 || val_acc:0.7428571428571429
************* Saving model for Task-1 *************
************* End of Task-1 training *************
************* Re-calibrating model parameters *************
************* Task: 2, Number of Trainable Parameters: 72330 *************
Files already downloaded and verified
Task-2 -->  CIFAR100 (train) loaded! Num. Samples: 4522
Files already downloaded and verified
Task-2 -->  CIFAR100 (val) loaded! Num. Samples: 478


Epoch 1/50: 71it [00:05, 12.84it/s]


tr_loss:1.1640973567962647 || val_loss:1.0842598676681519 || val_acc:0.6129707112970711


Epoch 2/50: 71it [00:05, 12.97it/s]


tr_loss:0.945870840549469 || val_loss:0.9504778981208801 || val_acc:0.6401673640167364


Epoch 3/50: 71it [00:05, 12.89it/s]


tr_loss:0.9502762675285339 || val_loss:0.8590600490570068 || val_acc:0.6736401673640168


Epoch 4/50: 71it [00:05, 12.99it/s]


tr_loss:0.8858914375305176 || val_loss:0.8324997425079346 || val_acc:0.696652719665272


Epoch 5/50: 71it [00:05, 12.96it/s]


tr_loss:0.737954032421112 || val_loss:0.7737587094306946 || val_acc:0.7154811715481172


Epoch 6/50: 71it [00:05, 12.91it/s]


tr_loss:0.7550932884216308 || val_loss:0.7609637379646301 || val_acc:0.7259414225941423


Epoch 7/50: 71it [00:05, 13.00it/s]


tr_loss:0.7184871554374694 || val_loss:0.7234505414962769 || val_acc:0.7447698744769874


Epoch 8/50: 71it [00:05, 12.94it/s]


tr_loss:0.690750241279602 || val_loss:0.6811475157737732 || val_acc:0.7594142259414226


Epoch 9/50: 71it [00:05, 12.96it/s]


tr_loss:0.7241692543029785 || val_loss:0.6704269647598267 || val_acc:0.7740585774058577


Epoch 10/50: 71it [00:05, 12.98it/s]


tr_loss:0.6842054545879364 || val_loss:0.6910880208015442 || val_acc:0.7322175732217573


Epoch 11/50: 71it [00:05, 12.87it/s]


tr_loss:0.6283735632896423 || val_loss:0.6493304967880249 || val_acc:0.7740585774058577


Epoch 12/50: 71it [00:05, 12.92it/s]


tr_loss:0.5659251987934113 || val_loss:0.6395750045776367 || val_acc:0.7824267782426778


Epoch 13/50: 71it [00:05, 12.93it/s]


tr_loss:0.49334356784820554 || val_loss:0.659731388092041 || val_acc:0.7866108786610879


Epoch 14/50: 71it [00:05, 12.97it/s]


tr_loss:0.625325620174408 || val_loss:0.63229900598526 || val_acc:0.7803347280334728


Epoch 15/50: 71it [00:05, 13.01it/s]


tr_loss:0.4276478230953217 || val_loss:0.6826764941215515 || val_acc:0.7845188284518828


Epoch 16/50: 71it [00:05, 12.97it/s]


tr_loss:0.4494836628437042 || val_loss:0.6178355813026428 || val_acc:0.803347280334728


Epoch 17/50: 71it [00:05, 12.97it/s]


tr_loss:0.43498839139938356 || val_loss:0.6282447576522827 || val_acc:0.797071129707113


Epoch 18/50: 71it [00:05, 13.01it/s]


tr_loss:0.4453473389148712 || val_loss:0.6182925701141357 || val_acc:0.7887029288702929


Epoch 19/50: 71it [00:05, 12.96it/s]


tr_loss:0.4631177306175232 || val_loss:0.6298447847366333 || val_acc:0.7928870292887029


Epoch 20/50: 71it [00:05, 13.01it/s]


tr_loss:0.5429121434688569 || val_loss:0.6271401643753052 || val_acc:0.801255230125523


Epoch 21/50: 71it [00:05, 12.98it/s]


tr_loss:0.4154327869415283 || val_loss:0.6117591857910156 || val_acc:0.7949790794979079


Epoch 22/50: 71it [00:05, 12.90it/s]


tr_loss:0.4468234658241272 || val_loss:0.6248109340667725 || val_acc:0.7949790794979079


Epoch 23/50: 71it [00:05, 12.91it/s]


tr_loss:0.41073086857795715 || val_loss:0.6189978718757629 || val_acc:0.799163179916318


Epoch 24/50: 71it [00:05, 12.94it/s]


tr_loss:0.4636756956577301 || val_loss:0.6170995235443115 || val_acc:0.7949790794979079


Epoch 25/50: 71it [00:05, 12.90it/s]


tr_loss:0.49249895215034484 || val_loss:0.61562180519104 || val_acc:0.801255230125523


Epoch 26/50: 71it [00:05, 12.94it/s]


tr_loss:0.4459474146366119 || val_loss:0.6287533640861511 || val_acc:0.7928870292887029


Epoch 27/50: 71it [00:05, 12.93it/s]


tr_loss:0.4178562223911285 || val_loss:0.6170898675918579 || val_acc:0.803347280334728


Epoch 28/50: 71it [00:05, 12.91it/s]


tr_loss:0.4541646599769592 || val_loss:0.6241954565048218 || val_acc:0.799163179916318


Epoch 29/50: 71it [00:05, 12.91it/s]


tr_loss:0.44684510231018065 || val_loss:0.6253374814987183 || val_acc:0.797071129707113


Epoch 30/50: 71it [00:05, 12.91it/s]


tr_loss:0.4899737358093262 || val_loss:0.6265538334846497 || val_acc:0.7907949790794979


Epoch 31/50: 71it [00:05, 12.96it/s]


tr_loss:0.41133266091346743 || val_loss:0.6220964789390564 || val_acc:0.7907949790794979


Epoch 32/50: 71it [00:05, 12.97it/s]


tr_loss:0.3417894124984741 || val_loss:0.6190374493598938 || val_acc:0.7907949790794979


Epoch 33/50: 71it [00:05, 12.99it/s]


tr_loss:0.41157236099243166 || val_loss:0.6210944056510925 || val_acc:0.797071129707113


Epoch 34/50: 71it [00:05, 12.98it/s]


tr_loss:0.3916390061378479 || val_loss:0.6236007213592529 || val_acc:0.7866108786610879


Epoch 35/50: 71it [00:05, 12.98it/s]


tr_loss:0.42758908271789553 || val_loss:0.6165579557418823 || val_acc:0.7928870292887029


Epoch 36/50: 71it [00:05, 13.00it/s]


tr_loss:0.4648649454116821 || val_loss:0.6217816472053528 || val_acc:0.7907949790794979


Epoch 37/50: 71it [00:05, 13.01it/s]


tr_loss:0.4307787775993347 || val_loss:0.6194223165512085 || val_acc:0.7887029288702929


Epoch 38/50: 71it [00:05, 13.02it/s]


tr_loss:0.41340755820274355 || val_loss:0.6285221576690674 || val_acc:0.797071129707113


Epoch 39/50: 71it [00:05, 12.86it/s]


tr_loss:0.4534771263599396 || val_loss:0.6229255795478821 || val_acc:0.7928870292887029


Epoch 40/50: 71it [00:05, 12.91it/s]


tr_loss:0.461103630065918 || val_loss:0.6158715486526489 || val_acc:0.7949790794979079


Epoch 41/50: 71it [00:05, 12.89it/s]


tr_loss:0.6314027726650238 || val_loss:0.6164736151695251 || val_acc:0.799163179916318


Epoch 42/50: 71it [00:05, 12.90it/s]


tr_loss:0.4358730375766754 || val_loss:0.6198263764381409 || val_acc:0.7907949790794979


Epoch 43/50: 71it [00:05, 13.00it/s]


tr_loss:0.38533387184143064 || val_loss:0.6176338791847229 || val_acc:0.799163179916318


Epoch 44/50: 71it [00:05, 12.99it/s]


tr_loss:0.4352666616439819 || val_loss:0.6161808371543884 || val_acc:0.797071129707113


Epoch 45/50: 71it [00:05, 12.97it/s]


tr_loss:0.3862448990345001 || val_loss:0.6191987991333008 || val_acc:0.7928870292887029


Epoch 46/50: 71it [00:05, 12.99it/s]


tr_loss:0.36441047191619874 || val_loss:0.6196251511573792 || val_acc:0.7887029288702929


Epoch 47/50: 71it [00:05, 12.97it/s]


tr_loss:0.39151968955993655 || val_loss:0.6212604641914368 || val_acc:0.799163179916318


Epoch 48/50: 71it [00:05, 13.02it/s]


tr_loss:0.3499688506126404 || val_loss:0.6154609322547913 || val_acc:0.797071129707113


Epoch 49/50: 71it [00:05, 12.93it/s]


tr_loss:0.3867788970470428 || val_loss:0.6142446398735046 || val_acc:0.797071129707113


Epoch 50/50: 71it [00:05, 12.82it/s]


tr_loss:0.4713997721672058 || val_loss:0.6251706480979919 || val_acc:0.7887029288702929
************* Saving model for Task-2 *************
************* End of Task-2 training *************
************* Re-calibrating model parameters *************
************* Task: 3, Number of Trainable Parameters: 72330 *************
Files already downloaded and verified
Task-3 -->  CIFAR100 (train) loaded! Num. Samples: 4497
Files already downloaded and verified
Task-3 -->  CIFAR100 (val) loaded! Num. Samples: 503


Epoch 1/50: 71it [00:05, 12.83it/s]


tr_loss:1.1903029203414917 || val_loss:1.1559218168258667 || val_acc:0.558648111332008


Epoch 2/50: 71it [00:05, 12.99it/s]


tr_loss:1.0701210260391236 || val_loss:1.064353585243225 || val_acc:0.6063618290258449


Epoch 3/50: 71it [00:05, 12.99it/s]


tr_loss:1.0465854406356812 || val_loss:0.9951637983322144 || val_acc:0.6262425447316103


Epoch 4/50: 71it [00:05, 12.96it/s]


tr_loss:0.9041853666305542 || val_loss:0.912245512008667 || val_acc:0.6600397614314115


Epoch 5/50: 71it [00:05, 13.01it/s]


tr_loss:1.0180741906166078 || val_loss:0.9049649834632874 || val_acc:0.6679920477137177


Epoch 6/50: 71it [00:05, 13.01it/s]


tr_loss:0.8066420912742615 || val_loss:0.8536190986633301 || val_acc:0.679920477137177


Epoch 7/50: 71it [00:05, 13.02it/s]


tr_loss:0.8067280411720276 || val_loss:0.8487748503684998 || val_acc:0.6600397614314115


Epoch 8/50: 71it [00:05, 13.07it/s]


tr_loss:0.7659523963928223 || val_loss:0.8244020342826843 || val_acc:0.6938369781312127


Epoch 9/50: 71it [00:05, 13.03it/s]


tr_loss:0.8082302331924438 || val_loss:0.7878863215446472 || val_acc:0.7395626242544732


Epoch 10/50: 71it [00:05, 13.05it/s]


tr_loss:0.6564334869384766 || val_loss:0.8051967620849609 || val_acc:0.6918489065606361


Epoch 11/50: 71it [00:05, 12.93it/s]


tr_loss:0.6501709282398224 || val_loss:0.798811137676239 || val_acc:0.705765407554672


Epoch 12/50: 71it [00:05, 12.98it/s]


tr_loss:0.6661069869995118 || val_loss:0.8006746768951416 || val_acc:0.709741550695825


Epoch 13/50: 71it [00:05, 12.96it/s]


tr_loss:0.6722061991691589 || val_loss:0.7663511037826538 || val_acc:0.7236580516898609


Epoch 14/50: 71it [00:05, 12.99it/s]


tr_loss:0.6489799976348877 || val_loss:0.7236970663070679 || val_acc:0.7375745526838966


Epoch 15/50: 71it [00:05, 13.04it/s]


tr_loss:0.7327689409255982 || val_loss:0.7327306866645813 || val_acc:0.7475149105367793


Epoch 16/50: 71it [00:05, 13.02it/s]


tr_loss:0.5771673798561097 || val_loss:0.7042722702026367 || val_acc:0.7455268389662028


Epoch 17/50: 71it [00:05, 13.05it/s]


tr_loss:0.6303881168365478 || val_loss:0.7065795660018921 || val_acc:0.757455268389662


Epoch 18/50: 71it [00:05, 13.02it/s]


tr_loss:0.7443994581699371 || val_loss:0.6923041343688965 || val_acc:0.7495029821073559


Epoch 19/50: 71it [00:05, 12.99it/s]


tr_loss:0.5245803475379944 || val_loss:0.7018133401870728 || val_acc:0.7495029821073559


Epoch 20/50: 71it [00:05, 13.03it/s]


tr_loss:0.5889911413192749 || val_loss:0.6876325011253357 || val_acc:0.7554671968190855


Epoch 21/50: 71it [00:05, 13.05it/s]


tr_loss:0.477855384349823 || val_loss:0.7015701532363892 || val_acc:0.7554671968190855


Epoch 22/50: 71it [00:05, 12.95it/s]


tr_loss:0.5745182514190674 || val_loss:0.6940501928329468 || val_acc:0.7415506958250497


Epoch 23/50: 71it [00:05, 12.95it/s]


tr_loss:0.5336088895797729 || val_loss:0.6945773363113403 || val_acc:0.7654075546719682


Epoch 24/50: 71it [00:05, 13.00it/s]


tr_loss:0.5913077116012573 || val_loss:0.6998506784439087 || val_acc:0.7514910536779325


Epoch 25/50: 71it [00:05, 12.86it/s]


tr_loss:0.49361287951469424 || val_loss:0.7024661898612976 || val_acc:0.757455268389662


Epoch 26/50: 71it [00:05, 12.98it/s]


tr_loss:0.5733126640319824 || val_loss:0.6922541260719299 || val_acc:0.7594433399602386


Epoch 27/50: 71it [00:05, 12.88it/s]


tr_loss:0.621991765499115 || val_loss:0.6871843934059143 || val_acc:0.7614314115308151


Epoch 28/50: 71it [00:05, 12.93it/s]


tr_loss:0.6201090574264526 || val_loss:0.6879518628120422 || val_acc:0.7554671968190855


Epoch 29/50: 71it [00:05, 12.92it/s]


tr_loss:0.5234213948249817 || val_loss:0.6733434796333313 || val_acc:0.7673956262425448


Epoch 30/50: 71it [00:05, 12.83it/s]


tr_loss:0.5064740478992462 || val_loss:0.7055284380912781 || val_acc:0.7534791252485089


Epoch 31/50: 71it [00:05, 13.01it/s]


tr_loss:0.6887930691242218 || val_loss:0.6726944446563721 || val_acc:0.7753479125248509


Epoch 32/50: 71it [00:05, 13.03it/s]


tr_loss:0.5092963695526123 || val_loss:0.7114103436470032 || val_acc:0.7375745526838966


Epoch 33/50: 71it [00:05, 13.04it/s]


tr_loss:0.5749727129936218 || val_loss:0.6852276921272278 || val_acc:0.7654075546719682


Epoch 34/50: 71it [00:05, 13.00it/s]


tr_loss:0.6621574878692627 || val_loss:0.6697313189506531 || val_acc:0.7733598409542743


Epoch 35/50: 71it [00:05, 13.05it/s]


tr_loss:0.5573899030685425 || val_loss:0.685157299041748 || val_acc:0.757455268389662


Epoch 36/50: 71it [00:05, 13.06it/s]


tr_loss:0.5713557064533233 || val_loss:0.6869186758995056 || val_acc:0.7673956262425448


Epoch 37/50: 71it [00:05, 12.98it/s]


tr_loss:0.5455976784229278 || val_loss:0.682327151298523 || val_acc:0.757455268389662


Epoch 38/50: 71it [00:05, 12.94it/s]


tr_loss:0.5193030953407287 || val_loss:0.6946737766265869 || val_acc:0.7673956262425448


Epoch 39/50: 71it [00:05, 12.92it/s]


tr_loss:0.5726443529129028 || val_loss:0.6775751113891602 || val_acc:0.7614314115308151


Epoch 40/50: 71it [00:05, 12.96it/s]


tr_loss:0.49678426384925845 || val_loss:0.6811855435371399 || val_acc:0.7673956262425448


Epoch 41/50: 71it [00:05, 12.95it/s]


tr_loss:0.5122740745544434 || val_loss:0.6842774748802185 || val_acc:0.7693836978131213


Epoch 42/50: 71it [00:05, 12.95it/s]


tr_loss:0.56253702044487 || val_loss:0.6851699352264404 || val_acc:0.7673956262425448


Epoch 43/50: 71it [00:05, 12.99it/s]


tr_loss:0.5589003503322602 || val_loss:0.6858750581741333 || val_acc:0.7634194831013916


Epoch 44/50: 71it [00:05, 12.93it/s]


tr_loss:0.6141964614391326 || val_loss:0.6968899369239807 || val_acc:0.757455268389662


Epoch 45/50: 71it [00:05, 12.96it/s]


tr_loss:0.6210561692714691 || val_loss:0.6765484809875488 || val_acc:0.7634194831013916


Epoch 46/50: 71it [00:05, 13.00it/s]


tr_loss:0.7209311187267303 || val_loss:0.6745448708534241 || val_acc:0.7654075546719682


Epoch 47/50: 71it [00:05, 13.03it/s]


tr_loss:0.5821575880050659 || val_loss:0.6903067231178284 || val_acc:0.7514910536779325


Epoch 48/50: 71it [00:05, 13.03it/s]


tr_loss:0.4979820430278778 || val_loss:0.6709349155426025 || val_acc:0.7673956262425448


Epoch 49/50: 71it [00:05, 12.93it/s]


tr_loss:0.661302101612091 || val_loss:0.6980167627334595 || val_acc:0.7534791252485089


Epoch 50/50: 71it [00:05, 12.96it/s]


tr_loss:0.4639317452907562 || val_loss:0.6861414313316345 || val_acc:0.7634194831013916
************* Saving model for Task-3 *************
************* End of Task-3 training *************
************* Re-calibrating model parameters *************
************* Task: 4, Number of Trainable Parameters: 72330 *************
Files already downloaded and verified
Task-4 -->  CIFAR100 (train) loaded! Num. Samples: 4507
Files already downloaded and verified
Task-4 -->  CIFAR100 (val) loaded! Num. Samples: 493


Epoch 1/50: 71it [00:05, 12.74it/s]


tr_loss:1.1134942293167114 || val_loss:1.058134913444519 || val_acc:0.6125760649087221


Epoch 2/50: 71it [00:05, 12.93it/s]


tr_loss:0.962991452217102 || val_loss:0.9437594413757324 || val_acc:0.6632860040567952


Epoch 3/50: 71it [00:05, 13.04it/s]


tr_loss:0.7636104345321655 || val_loss:0.9355961084365845 || val_acc:0.6693711967545639


Epoch 4/50: 71it [00:05, 12.98it/s]


tr_loss:0.848019540309906 || val_loss:0.9189106225967407 || val_acc:0.6713995943204868


Epoch 5/50: 71it [00:05, 12.99it/s]


tr_loss:0.8229149103164672 || val_loss:0.8536023497581482 || val_acc:0.6795131845841785


Epoch 6/50: 71it [00:05, 12.95it/s]


tr_loss:0.8718612790107727 || val_loss:0.8547913432121277 || val_acc:0.6855983772819473


Epoch 7/50: 71it [00:05, 12.98it/s]


tr_loss:0.8056899905204773 || val_loss:0.8292611837387085 || val_acc:0.7221095334685599


Epoch 8/50: 71it [00:05, 13.03it/s]


tr_loss:0.7526571154594421 || val_loss:0.8063526749610901 || val_acc:0.7079107505070994


Epoch 9/50: 71it [00:05, 12.86it/s]


tr_loss:0.7449875593185424 || val_loss:0.8121708631515503 || val_acc:0.6997971602434077


Epoch 10/50: 71it [00:05, 12.94it/s]


tr_loss:0.6345494270324707 || val_loss:0.8111351132392883 || val_acc:0.7221095334685599


Epoch 11/50: 71it [00:05, 12.94it/s]


tr_loss:0.6593064665794373 || val_loss:0.803558886051178 || val_acc:0.7261663286004056


Epoch 12/50: 71it [00:05, 12.94it/s]


tr_loss:0.5103090703487396 || val_loss:0.7836365699768066 || val_acc:0.718052738336714


Epoch 13/50: 71it [00:05, 13.03it/s]


tr_loss:0.6021898269653321 || val_loss:0.7534687519073486 || val_acc:0.7383367139959433


Epoch 14/50: 71it [00:05, 12.99it/s]


tr_loss:0.49721076488494875 || val_loss:0.7737118005752563 || val_acc:0.7363083164300203


Epoch 15/50: 71it [00:05, 12.99it/s]


tr_loss:0.4718460559844971 || val_loss:0.7841794490814209 || val_acc:0.7322515212981744


Epoch 16/50: 71it [00:05, 12.99it/s]


tr_loss:0.49366342425346377 || val_loss:0.7743039727210999 || val_acc:0.7342799188640974


Epoch 17/50: 71it [00:05, 13.00it/s]


tr_loss:0.504901385307312 || val_loss:0.7698036432266235 || val_acc:0.7342799188640974


Epoch 18/50: 71it [00:05, 13.07it/s]


tr_loss:0.5206537544727325 || val_loss:0.7715880870819092 || val_acc:0.7403651115618661


Epoch 19/50: 71it [00:05, 13.02it/s]


tr_loss:0.47541981339454653 || val_loss:0.7596243619918823 || val_acc:0.7342799188640974


Epoch 20/50: 71it [00:05, 12.93it/s]


tr_loss:0.5405411660671234 || val_loss:0.7623717188835144 || val_acc:0.7403651115618661


Epoch 21/50: 71it [00:05, 12.90it/s]


tr_loss:0.4662266135215759 || val_loss:0.7579829096794128 || val_acc:0.7342799188640974


Epoch 22/50: 71it [00:05, 12.96it/s]


tr_loss:0.45124499797821044 || val_loss:0.7575616240501404 || val_acc:0.7464503042596349


Epoch 23/50: 71it [00:05, 12.91it/s]


tr_loss:0.498213118314743 || val_loss:0.7459832429885864 || val_acc:0.742393509127789


Epoch 24/50: 71it [00:05, 12.87it/s]


tr_loss:0.520090526342392 || val_loss:0.7526573538780212 || val_acc:0.7484787018255578


Epoch 25/50: 71it [00:05, 13.01it/s]


tr_loss:0.42855467796325686 || val_loss:0.7655755281448364 || val_acc:0.7383367139959433


Epoch 26/50: 71it [00:05, 12.95it/s]


tr_loss:0.43684431314468386 || val_loss:0.7571779489517212 || val_acc:0.744421906693712


Epoch 27/50: 71it [00:05, 12.94it/s]


tr_loss:0.46735326051712034 || val_loss:0.7565144896507263 || val_acc:0.7464503042596349


Epoch 28/50: 71it [00:05, 13.01it/s]


tr_loss:0.5098197937011719 || val_loss:0.75801682472229 || val_acc:0.742393509127789


Epoch 29/50: 71it [00:05, 12.98it/s]


tr_loss:0.45001755356788636 || val_loss:0.7425962090492249 || val_acc:0.742393509127789


Epoch 30/50: 71it [00:05, 12.97it/s]


tr_loss:0.4744725823402405 || val_loss:0.7579812407493591 || val_acc:0.7342799188640974


Epoch 31/50: 71it [00:05, 12.92it/s]


tr_loss:0.37703452706336976 || val_loss:0.7625446319580078 || val_acc:0.7403651115618661


Epoch 32/50: 71it [00:05, 12.88it/s]


tr_loss:0.4160816311836243 || val_loss:0.7593300342559814 || val_acc:0.7363083164300203


Epoch 33/50: 71it [00:05, 12.86it/s]


tr_loss:0.4508142352104187 || val_loss:0.7586148977279663 || val_acc:0.742393509127789


Epoch 34/50: 71it [00:05, 12.77it/s]


tr_loss:0.45379987359046936 || val_loss:0.7524921894073486 || val_acc:0.7484787018255578


Epoch 35/50: 71it [00:05, 12.80it/s]


tr_loss:0.5800617396831512 || val_loss:0.7486752867698669 || val_acc:0.7464503042596349


Epoch 36/50: 71it [00:05, 13.00it/s]


tr_loss:0.45830312967300413 || val_loss:0.7583316564559937 || val_acc:0.7403651115618661


Epoch 37/50: 71it [00:05, 12.95it/s]


tr_loss:0.4143089592456818 || val_loss:0.7500107288360596 || val_acc:0.7403651115618661


Epoch 38/50: 71it [00:05, 12.89it/s]


tr_loss:0.39950591921806333 || val_loss:0.7489092946052551 || val_acc:0.7484787018255578


Epoch 39/50: 71it [00:05, 12.89it/s]


tr_loss:0.5306411385536194 || val_loss:0.7574650049209595 || val_acc:0.742393509127789


Epoch 40/50: 71it [00:05, 12.89it/s]


tr_loss:0.44194323420524595 || val_loss:0.7541794776916504 || val_acc:0.7403651115618661


Epoch 41/50: 71it [00:05, 12.92it/s]


tr_loss:0.4296289950609207 || val_loss:0.7480553388595581 || val_acc:0.7484787018255578


Epoch 42/50: 71it [00:05, 12.98it/s]


tr_loss:0.4619780838489532 || val_loss:0.7479115128517151 || val_acc:0.7525354969574036


Epoch 43/50: 71it [00:05, 12.99it/s]


tr_loss:0.4671792805194855 || val_loss:0.7583875060081482 || val_acc:0.7525354969574036


Epoch 44/50: 71it [00:05, 12.97it/s]


tr_loss:0.3888672947883606 || val_loss:0.7520000338554382 || val_acc:0.742393509127789


Epoch 45/50: 71it [00:05, 13.03it/s]


tr_loss:0.4587083041667938 || val_loss:0.7509186267852783 || val_acc:0.7505070993914807


Epoch 46/50: 71it [00:05, 12.90it/s]


tr_loss:0.42074993848800657 || val_loss:0.753028392791748 || val_acc:0.7464503042596349


Epoch 47/50: 71it [00:05, 12.97it/s]


tr_loss:0.5194195449352265 || val_loss:0.7499547600746155 || val_acc:0.7464503042596349


Epoch 48/50: 71it [00:05, 12.94it/s]


tr_loss:0.5059520840644837 || val_loss:0.7577336430549622 || val_acc:0.742393509127789


Epoch 49/50: 71it [00:05, 12.89it/s]


tr_loss:0.550886595249176 || val_loss:0.7507244944572449 || val_acc:0.7464503042596349


Epoch 50/50: 71it [00:05, 12.92it/s]


tr_loss:0.5153862357139587 || val_loss:0.7556615471839905 || val_acc:0.7464503042596349
************* Saving model for Task-4 *************
************* End of Task-4 training *************
************* Re-calibrating model parameters *************
************* Task: 5, Number of Trainable Parameters: 72330 *************
Files already downloaded and verified
Task-5 -->  CIFAR100 (train) loaded! Num. Samples: 4507
Files already downloaded and verified
Task-5 -->  CIFAR100 (val) loaded! Num. Samples: 493


Epoch 1/50: 71it [00:05, 12.72it/s]


tr_loss:1.194153356552124 || val_loss:1.1042119264602661 || val_acc:0.6308316430020284


Epoch 2/50: 71it [00:05, 13.03it/s]


tr_loss:1.068712842464447 || val_loss:1.0390715599060059 || val_acc:0.6572008113590264


Epoch 3/50: 71it [00:05, 13.03it/s]


tr_loss:1.0013574600219726 || val_loss:0.9251660108566284 || val_acc:0.6896551724137931


Epoch 4/50: 71it [00:05, 12.98it/s]


tr_loss:0.8775189757347107 || val_loss:0.944729745388031 || val_acc:0.6754563894523327


Epoch 5/50: 71it [00:05, 12.87it/s]


tr_loss:0.9739128232002259 || val_loss:0.8913525938987732 || val_acc:0.6835699797160243


Epoch 6/50: 71it [00:05, 12.95it/s]


tr_loss:0.7376175642013549 || val_loss:0.8512121438980103 || val_acc:0.7099391480730223


Epoch 7/50: 71it [00:05, 13.01it/s]


tr_loss:0.9019188165664673 || val_loss:0.8034939169883728 || val_acc:0.7261663286004056


Epoch 8/50: 71it [00:05, 12.90it/s]


tr_loss:0.8977561354637146 || val_loss:0.803005576133728 || val_acc:0.7221095334685599


Epoch 9/50: 71it [00:05, 12.88it/s]


tr_loss:0.7493572473526001 || val_loss:0.8044775128364563 || val_acc:0.7302231237322515


Epoch 10/50: 71it [00:05, 12.85it/s]


tr_loss:0.7594454407691955 || val_loss:0.7866061925888062 || val_acc:0.7383367139959433


Epoch 11/50: 71it [00:05, 12.80it/s]


tr_loss:0.6053801536560058 || val_loss:0.7588420510292053 || val_acc:0.7484787018255578


Epoch 12/50: 71it [00:05, 12.93it/s]


tr_loss:0.6649510383605957 || val_loss:0.7816949486732483 || val_acc:0.7241379310344828


Epoch 13/50: 71it [00:05, 12.98it/s]


tr_loss:0.5269255936145782 || val_loss:0.7822637557983398 || val_acc:0.742393509127789


Epoch 14/50: 71it [00:05, 13.00it/s]


tr_loss:0.6736033797264099 || val_loss:0.7645822167396545 || val_acc:0.7241379310344828


Epoch 15/50: 71it [00:05, 12.96it/s]


tr_loss:0.7836273431777954 || val_loss:0.747050404548645 || val_acc:0.7484787018255578


Epoch 16/50: 71it [00:05, 12.96it/s]


tr_loss:0.554613345861435 || val_loss:0.7214124202728271 || val_acc:0.7484787018255578


Epoch 17/50: 71it [00:05, 13.02it/s]


tr_loss:0.5405933618545532 || val_loss:0.7131307125091553 || val_acc:0.7606490872210954


Epoch 18/50: 71it [00:05, 13.07it/s]


tr_loss:0.5078785955905915 || val_loss:0.7225040793418884 || val_acc:0.7464503042596349


Epoch 19/50: 71it [00:05, 13.05it/s]


tr_loss:0.588909262418747 || val_loss:0.7243214249610901 || val_acc:0.7484787018255578


Epoch 20/50: 71it [00:05, 12.91it/s]


tr_loss:0.6044780015945435 || val_loss:0.7233927845954895 || val_acc:0.7586206896551724


Epoch 21/50: 71it [00:05, 12.88it/s]


tr_loss:0.5479290187358856 || val_loss:0.7177453637123108 || val_acc:0.7586206896551724


Epoch 22/50: 71it [00:05, 12.86it/s]


tr_loss:0.6139064431190491 || val_loss:0.7199260592460632 || val_acc:0.7525354969574036


Epoch 23/50: 71it [00:05, 12.89it/s]


tr_loss:0.5034128427505493 || val_loss:0.7176386713981628 || val_acc:0.7586206896551724


Epoch 24/50: 71it [00:05, 12.85it/s]


tr_loss:0.6621410489082337 || val_loss:0.730292558670044 || val_acc:0.7586206896551724


Epoch 25/50: 71it [00:05, 12.95it/s]


tr_loss:0.5420450985431671 || val_loss:0.7229464650154114 || val_acc:0.7505070993914807


Epoch 26/50: 71it [00:05, 13.00it/s]


tr_loss:0.563895297050476 || val_loss:0.7271140217781067 || val_acc:0.7626774847870182


Epoch 27/50: 71it [00:05, 13.02it/s]


tr_loss:0.5458769500255585 || val_loss:0.717227578163147 || val_acc:0.7565922920892495


Epoch 28/50: 71it [00:05, 12.91it/s]


tr_loss:0.63800008893013 || val_loss:0.7225766777992249 || val_acc:0.7606490872210954


Epoch 29/50: 71it [00:05, 13.02it/s]


tr_loss:0.6740006744861603 || val_loss:0.7218376398086548 || val_acc:0.7565922920892495


Epoch 30/50: 71it [00:05, 12.91it/s]


tr_loss:0.47743244767189025 || val_loss:0.720404326915741 || val_acc:0.7565922920892495


Epoch 31/50: 71it [00:05, 12.88it/s]


tr_loss:0.5687049627304077 || val_loss:0.7253671288490295 || val_acc:0.7565922920892495


Epoch 32/50: 71it [00:05, 12.89it/s]


tr_loss:0.5941947937011719 || val_loss:0.7193524241447449 || val_acc:0.7586206896551724


Epoch 33/50: 71it [00:05, 12.89it/s]


tr_loss:0.39978086948394775 || val_loss:0.7166696786880493 || val_acc:0.7586206896551724


Epoch 34/50: 71it [00:05, 12.81it/s]


tr_loss:0.5438858807086945 || val_loss:0.7172093987464905 || val_acc:0.7586206896551724


Epoch 35/50: 71it [00:05, 12.92it/s]


tr_loss:0.573454988002777 || val_loss:0.7273545861244202 || val_acc:0.7606490872210954


Epoch 36/50: 71it [00:05, 12.92it/s]


tr_loss:0.5186596691608429 || val_loss:0.7190513610839844 || val_acc:0.7586206896551724


Epoch 37/50: 71it [00:05, 12.91it/s]


tr_loss:0.5196307539939881 || val_loss:0.7150562405586243 || val_acc:0.7667342799188641


Epoch 38/50: 71it [00:05, 12.81it/s]


tr_loss:0.6225524187088013 || val_loss:0.715934693813324 || val_acc:0.7626774847870182


Epoch 39/50: 71it [00:05, 12.81it/s]


tr_loss:0.6380015969276428 || val_loss:0.719076931476593 || val_acc:0.7586206896551724


Epoch 40/50: 71it [00:05, 12.92it/s]


tr_loss:0.5198655128479004 || val_loss:0.7145037055015564 || val_acc:0.7565922920892495


Epoch 41/50: 71it [00:05, 12.95it/s]


tr_loss:0.5943482875823974 || val_loss:0.7131211757659912 || val_acc:0.7545638945233266


Epoch 42/50: 71it [00:05, 12.94it/s]


tr_loss:0.5490561425685883 || val_loss:0.7150349617004395 || val_acc:0.7545638945233266


Epoch 43/50: 71it [00:05, 12.95it/s]


tr_loss:0.5197583973407746 || val_loss:0.7167837619781494 || val_acc:0.7626774847870182


Epoch 44/50: 71it [00:05, 12.95it/s]


tr_loss:0.4418071448802948 || val_loss:0.7194328904151917 || val_acc:0.7525354969574036


Epoch 45/50: 71it [00:05, 12.97it/s]


tr_loss:0.6321208655834198 || val_loss:0.7107242941856384 || val_acc:0.7606490872210954


Epoch 46/50: 71it [00:05, 12.93it/s]


tr_loss:0.7186906337738037 || val_loss:0.7209762930870056 || val_acc:0.7647058823529411


Epoch 47/50: 71it [00:05, 13.06it/s]


tr_loss:0.5510615408420563 || val_loss:0.7144817113876343 || val_acc:0.7586206896551724


Epoch 48/50: 71it [00:05, 12.87it/s]


tr_loss:0.5681861698627472 || val_loss:0.7208508849143982 || val_acc:0.7586206896551724


Epoch 49/50: 71it [00:05, 12.88it/s]


tr_loss:0.565522301197052 || val_loss:0.7175218462944031 || val_acc:0.7626774847870182


Epoch 50/50: 71it [00:05, 12.80it/s]


tr_loss:0.5736058533191681 || val_loss:0.7175626754760742 || val_acc:0.7586206896551724
************* Saving model for Task-5 *************
************* End of Task-5 training *************


In [None]:
def test(model_path='./models/splitCIFAR_experiment/', num_tasks=6):
    """
    Evaluate each saved per-task model on that task's test split.

    For every task index in ``range(num_tasks)`` the checkpoint
    ``splitcifar_model_task_<task_no>.pt`` is loaded from ``model_path``,
    put in eval mode and run over the loader returned by
    ``get_dataloader(subset='test', task_no=task_no)``. One summary line
    with the loss and accuracy is printed per task.

    The reported metrics are averaged over all test samples, with each batch
    weighted by its size, instead of being taken from the last batch only.
    The numbers therefore stay meaningful when the loader yields more than
    one batch. If a loader yields no batches, ``nan`` is reported for that
    task rather than stale values from a previous task.

    Args:
        model_path: directory prefix holding the checkpoints. It is
            concatenated directly with the file name, so it must end with a
            path separator.
        num_tasks: number of tasks to evaluate, starting at task 0.
    """
    # The criterion holds no per-task state, so a single instance is reused.
    criterion = nn.CrossEntropyLoss()

    for task_no in range(num_tasks):
        # load models per task
        series_num = 'splitcifar_model_task_' + str(task_no) + '.pt'
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model = torch.load(model_path + series_num, map_location=device)
        model = model.to(device)
        model.eval()
        # get task specific test data
        test_loader = get_dataloader(subset='test', task_no=task_no)

        loss_sum = 0.0
        acc_sum = 0.0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                yhat = model(x)

                # CrossEntropyLoss defaults to a per-batch mean, so scale by
                # the batch size to recover a per-sample average at the end.
                # NOTE(review): assumes accuracy() also returns a per-batch
                # fraction -- confirm against its definition.
                batch_size = y.size(0)
                loss_sum += criterion(yhat, y).item() * batch_size
                acc_sum += accuracy(y.to('cpu').numpy(), yhat.to('cpu').numpy()) * batch_size
                num_samples += batch_size

        if num_samples:
            test_loss = loss_sum / num_samples
            test_acc = acc_sum / num_samples
        else:
            # Empty loader: there is nothing to average.
            test_loss = test_acc = float('nan')
        print('Task-{} || test_loss:{} || test_acc:{}'.format(task_no, test_loss, test_acc))

In [None]:
# Evaluate every saved per-task model on its own test split and print the metrics.
test()

Files already downloaded and verified
Task-0 -->  CIFAR10 (test) loaded! Num. Samples: 10000
Task-0 || test_loss:0.6690012216567993 || test_acc:0.8379
Files already downloaded and verified
Task-1 -->  CIFAR100 (test) loaded! Num. Samples: 1000
Task-1 || test_loss:0.8082564473152161 || test_acc:0.74
Files already downloaded and verified
Task-2 -->  CIFAR100 (test) loaded! Num. Samples: 1000
Task-2 || test_loss:0.6675814986228943 || test_acc:0.784
Files already downloaded and verified
Task-3 -->  CIFAR100 (test) loaded! Num. Samples: 1000
Task-3 || test_loss:0.7235320806503296 || test_acc:0.726
Files already downloaded and verified
Task-4 -->  CIFAR100 (test) loaded! Num. Samples: 1000
Task-4 || test_loss:0.7428212761878967 || test_acc:0.756
Files already downloaded and verified
Task-5 -->  CIFAR100 (test) loaded! Num. Samples: 1000
Task-5 || test_loss:0.7409536242485046 || test_acc:0.751
