**Imports**

In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as trans
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim import Adam
from datetime import datetime

**Set Device**

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

**Model Architecture**

In [3]:
class MainConv(nn.Module):
    """Conv2d wrapper that keeps the activation of its most recent forward pass.

    The tensor produced by the last call is exposed as ``self.out`` so that
    code outside the module can read this layer's output after the network
    has been run.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride = stride,
            padding = padding,
        )
        # Populated by forward(); stays None until the layer has been called.
        self.out = None

    def forward(self, x):
        """Apply the convolution, cache the result on ``self.out`` and return it."""
        self.out = self.conv(x)
        return self.out

class ResNetBlock(nn.Module):
    """Basic residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    With ``downsample`` set, the first convolution uses stride 2, the channel
    count doubles, and the shortcut goes through a strided 1x1 convolution so
    that both branches have matching shapes. Otherwise the shortcut is the
    unmodified input.
    """

    def __init__(self, in_channels, downsample):
        super().__init__()
        stride, out_channels = (2, 2 * in_channels) if downsample else (1, in_channels)

        self.conv1 = MainConv(in_channels, out_channels, 3, stride = stride, padding = 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = MainConv(out_channels, out_channels, 3, stride = 1, padding = 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection for the shortcut; only needed when the shape changes.
        if downsample:
            self.convr = nn.Conv2d(in_channels, out_channels, 1, stride = stride, padding = 0)
        else:
            self.convr = None
        self.relu = nn.ReLU(inplace = True)

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.convr is None else self.convr(x)
        return self.relu(shortcut + residual)

class ResNet20(nn.Module):
    """Stem convolution, nine residual blocks and a linear classifier.

    The nine blocks form three groups of three. The first block of the second
    and third group downsamples, so the channel width goes 16 -> 32 -> 64.
    Everything is stored in ``self.blocks`` as one ``nn.Sequential``.
    """

    def __init__(self, n_classes):
        super().__init__()
        width = 16
        layers = [MainConv(3, width, 3, stride = 1, padding = 1), nn.ReLU(inplace = True)]
        for idx in range(9):
            # Blocks 3 and 6 open a new group: halve resolution, double width.
            opens_group = idx in (3, 6)
            layers.append(ResNetBlock(width, opens_group))
            if opens_group:
                width *= 2
        layers.extend([nn.AvgPool2d(8), nn.Flatten(), nn.Linear(width, n_classes)])
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole sequential stack and return the logits."""
        return self.blocks(x)

class PARLModel(nn.Module):
    """Ensemble of ``n_ensemble`` separately parameterised ResNet20 classifiers.

    Args:
        n_classes: number of output classes of every member.
        n_ensemble: number of classifiers in the ensemble.
    """

    def __init__(self, n_classes = 10, n_ensemble = 3):
        super(PARLModel, self).__init__()
        self.nets = nn.ModuleList([ResNet20(n_classes) for _ in range(n_ensemble)])
        self.n_ensemble = n_ensemble

    def forward(self, x):
        """Return a list with one output tensor per ensemble member, in order."""
        return [net(x) for net in self.nets]

    def init_weights(self):
        """Re-initialise every conv, linear and batch-norm layer in place.

        * Conv2d weights: normal with std ``sqrt(2 / (kh * kw * out_channels))``
          (conv biases are left untouched).
        * Linear weights: uniform in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``,
          bias set to zero.
        * Batch-norm: scale set to one, shift set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Linear):
                stdv = 1. / math.sqrt(m.weight.size(1))
                m.weight.data.uniform_(-stdv, stdv)
                if m.bias is not None:
                    m.bias.data.zero_()
            # The residual blocks use BatchNorm2d; checking only BatchNorm1d
            # meant this branch never matched any layer of the network.
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # weight/bias are None when the layer was built with affine=False.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

**Loss Functions**

In [4]:
class PARLLoss():
    """Cross-entropy loss for an ensemble plus a pairwise penalty on its members.

    The penalty is built from two quantities computed on the first ``n_layers``
    selected conv layers of every ensemble member, for every pair of members:
    the cosine similarity of weight gradients and the correlation of the layer
    outputs. ``init_loss_calculator`` must be called with the model before the
    object is used, because ``self.model`` and ``self.conv_layers`` are None
    until then.
    """
    def __init__(self, n_layers, gamma):
        """Store the hyperparameters.

        n_layers: the first n_layers conv layers are used in the loss function
        gamma: hyperparameter controlling the weight of the penalty term

        Raises ValueError when n_layers is not positive.
        """
        if n_layers <= 0:
            raise ValueError('PARLLoss.n_layers must be a positive integer.')
        self.n_layers = n_layers
        self.gamma = gamma
        self.loss_fn = nn.CrossEntropyLoss()
        # The three attributes below are filled in by init_loss_calculator().
        self.model = None
        self.n_ensemble = None
        self.conv_layers = None

    def init_loss_calculator(self, model):
        """Bind the loss to ``model`` and pick the conv layers it will inspect.

        Walks ``blocks`` of the first ensemble member in order. A MainConv entry
        contributes itself; a ResNetBlock entry contributes its conv1 and then
        its conv2. Collection stops as soon as ``n_layers`` layers were found.
        After the call, ``self.conv_layers[i][j]`` is the i-th selected layer of
        ensemble member j.
        """
        self.model = model
        self.n_ensemble = model.n_ensemble
        ensemble = [model.nets[i].blocks for i in range(model.n_ensemble)]
        conv_layers = []
        k = 0   # number of layers collected so far
        # Only member 0 is type-checked; the same index i is then used for every
        # member -- assumes all members share one architecture (TODO confirm).
        for i in range(len(ensemble[0])):
            if isinstance(ensemble[0][i], MainConv):
                conv_layers.append([ensemble[j][i] for j in range(model.n_ensemble)])
                k += 1
                if k == self.n_layers:
                    break
            elif isinstance(ensemble[0][i], ResNetBlock):
                conv_layers.append([ensemble[j][i].conv1 for j in range(model.n_ensemble)])
                k += 1
                if k == self.n_layers:
                    break
                conv_layers.append([ensemble[j][i].conv2 for j in range(model.n_ensemble)])
                k += 1
                if k == self.n_layers:
                    break
        # May hold fewer than n_layers entries if the model has fewer such layers.
        self.conv_layers = conv_layers
        print(f'Successfully initialized PARLLoss object. Using {len(self.conv_layers)} layers\n')

    def cosine_sim(self, x, y):
        """Return sum(x * y) / (||x||_2 * ||y||_2 + 1e-12) as a tensor.

        The norm product is turned into a Python float with ``.item()``, so
        autograd only tracks the numerator.
        """
        return (x * y).sum() / ((torch.norm(x, p = 2) * torch.norm(y, p = 2)).item() + 1e-12)

    def correlation(self, x, y):
        """Return the mean, over all leading entries, of the correlation along the last axis.

        NOTE(review): the covariance divides by n, while torch.std applies
        Bessel's correction (n - 1) by default -- confirm this mismatch is intended.
        """
        n = x.size(dim = -1)
        x_centered = x - torch.mean(x, dim = -1, keepdim = True)
        y_centered = y - torch.mean(y, dim = -1, keepdim = True)
        cov_xy = torch.sum(x_centered * y_centered, dim = -1) / n
        # 1e-12 guards against division by zero for constant rows.
        corr_xy = cov_xy / (torch.std(x, dim = -1) * torch.std(y, dim = -1) + 1e-12)
        return corr_xy.mean()

    def __call__(self, X, y_true, train = True):
        """Run the bound model on ``X`` and return the scalar loss tensor.

        X: input batch passed to the model; its first axis is used as batch size.
        y_true: targets handed to nn.CrossEntropyLoss.
        train: forwarded as ``create_graph`` to the internal backward calls.

        Returns the cross-entropy averaged over ensemble members plus
        ``gamma * cosm_sum * corr_sum / n_layers``.
        """
        # Forward pass; every MainConv stores its output on `.out` while it runs.
        model_out = self.model(X)
        batch_size = X.shape[0]
        corr_sum = 0   # accumulated output correlations over layers and member pairs
        cosm_sum = 0   # accumulated gradient cosine similarities over layers and member pairs
        for i in range(len(self.conv_layers)):
            grads = []
            for j in range(self.n_ensemble):
                # Back-propagate a tensor of ones from layer i's output of member j,
                # which populates `.grad` on the parameters upstream of it.
                # retain_graph keeps the graph alive for the later backward calls.
                self.conv_layers[i][j].out.backward(torch.ones_like(self.conv_layers[i][j].out), retain_graph = True, create_graph = train)
                # One flat vector per member: weight gradients of selected layers 0..i.
                grads.append(torch.cat([self.conv_layers[k][j].conv.weight.grad.clone().reshape(-1) for k in range(i + 1)]))
            # Clear the gradients so the next layer's backward starts from scratch.
            self.model.zero_grad()
            for j in range(self.n_ensemble - 1):
                for k in range(j + 1, self.n_ensemble):
                    # Move the batch axis last and flatten the other three axes:
                    # each row then holds one output position across the batch.
                    f1 = torch.permute(self.conv_layers[i][j].out, (1, 2, 3, 0)).reshape(-1, batch_size)
                    f2 = torch.permute(self.conv_layers[i][k].out, (1, 2, 3, 0)).reshape(-1, batch_size)
                    # correlation() already returns a scalar, so .mean() leaves it unchanged.
                    corr_sum += self.correlation(f1, f2).mean()
                    cosm_sum += self.cosine_sim(grads[j], grads[k]).clone()
            del grads
            torch.cuda.empty_cache()
        # Cross-entropy averaged over the ensemble members.
        loss = 0
        for y_pred in model_out:
              loss += self.loss_fn(y_pred, y_true)
        loss /= self.n_ensemble

        # NOTE(review): normalised by the requested n_layers, not by
        # len(self.conv_layers); the two differ if fewer layers were found.
        return loss + (self.gamma * cosm_sum * corr_sum) / self.n_layers

**Utils**

In [5]:
def get_dataloader(dataset, batch_size = 64, num_workers = 4):
    """Build the train and test DataLoaders for a CIFAR dataset.

    Args:
        dataset: 'CIFAR-10' or 'CIFAR-100'. The same string is used as the
            root directory the data is downloaded to / read from.
        batch_size: batch size of both loaders.
        num_workers: number of worker processes of both loaders.

    Returns:
        (n_classes, train_dataloader, test_dataloader). Only the train loader
        shuffles and applies augmentation (random crop with edge padding and
        random horizontal flip).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'CIFAR-10':
        n_classes = 10
        train_mean_std = [(0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)]
        test_mean_std = [(0.4942, 0.4851, 0.4504), (0.2467, 0.2429, 0.2616)]
        dataset_obj = datasets.CIFAR10
    elif dataset == 'CIFAR-100':
        n_classes = 100
        train_mean_std = [(0.5071, 0.4866, 0.4409), (0.2673, 0.2564, 0.2762)]
        test_mean_std = [(0.5088, 0.4874, 0.4419), (0.2683, 0.2574, 0.2771)]
        dataset_obj = datasets.CIFAR100
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"Unsupported dataset {dataset!r}; expected 'CIFAR-10' or 'CIFAR-100'.")

    # Normalisation runs before the crop, so the 'edge' padding replicates
    # already-normalised border pixels.
    train_transforms = trans.Compose([
        trans.ToTensor(),
        trans.Normalize(*train_mean_std),
        trans.RandomCrop(32, padding = 4, padding_mode = 'edge'),
        trans.RandomHorizontalFlip()
    ])
    test_transforms = trans.Compose([
        trans.ToTensor(),
        trans.Normalize(*test_mean_std)
    ])
    train_data = dataset_obj(root = dataset, train = True, download = True, transform = train_transforms)
    test_data  = dataset_obj(root = dataset, train = False, download = True, transform = test_transforms)
    train_dataloader = DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    test_dataloader  = DataLoader(test_data, batch_size = batch_size, num_workers = num_workers)

    return n_classes, train_dataloader, test_dataloader

def train(model, train_dataloader, val_dataloader, epochs, train_loss_fn, scheduler, optimizer, checkpoint):
    """Train ``model`` and save the weights with the lowest validation loss.

    Args:
        model: network to optimise; also handed to
            ``train_loss_fn.init_loss_calculator``.
        train_dataloader: iterable of ``(inputs, targets)`` training batches.
        val_dataloader: iterable of ``(inputs, targets)`` validation batches.
        epochs: number of passes over ``train_dataloader``.
        train_loss_fn: object called as ``train_loss_fn(X, y_true, train = ...)``
            that runs the model itself and returns a scalar loss tensor.
        scheduler: learning-rate scheduler, stepped once per epoch.
        optimizer: optimiser over the parameters of ``model``.
        checkpoint: path the best ``state_dict`` is written to.

    Progress is printed to stdout; nothing is returned.
    """
    model.train()
    train_loss_fn.init_loss_calculator(model)
    n_batch = len(train_dataloader)
    best_val_loss = np.inf

    # train loop
    for epoch in range(1, epochs + 1):
        # Validation below switches to eval mode, so re-enable training mode.
        model.train()
        n_sample_seen = 0
        avg_train_loss = 0
        start_time = datetime.now()
        for batch, (x_train, y_train) in enumerate(train_dataloader):
            X, y_true = x_train.to(device), y_train.to(device)
            loss = train_loss_fn(X, y_true, train = True)
            # loss is a per-batch mean, so weight it by the batch size to get
            # a correct per-sample running average.
            n_in_batch = len(y_train)
            avg_train_loss = avg_train_loss * n_sample_seen + loss.item() * n_in_batch
            n_sample_seen += n_in_batch
            avg_train_loss /= n_sample_seen
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Drop the gradient tensors entirely rather than keeping zeroed ones.
            for param in model.parameters():
                param.grad = None
            end_time = datetime.now()
            print('\r', end = '')
            print(f"Epoch: {epoch}/{epochs} | Batch: {batch + 1}/{n_batch} | Learning_Rate: {np.around(scheduler.get_last_lr()[0], 6)} | Training_Loss: {np.around(avg_train_loss, 6)} | Elapsed_Time: {np.around((end_time - start_time).total_seconds(), 1)} s", end = '', flush = True)
        scheduler.step()

        # validation
        # Eval mode keeps validation batches from updating batch-norm running
        # statistics (which would end up in the checkpoint) and matches how the
        # saved weights are evaluated later. Gradients stay enabled here because
        # the loss object is free to call backward() internally.
        model.eval()
        avg_val_loss = 0
        n_sample_seen = 0
        for x_val, y_val in val_dataloader:
            X, y_true = x_val.to(device), y_val.to(device)
            loss = train_loss_fn(X, y_true, train = False)
            n_in_batch = len(y_val)
            avg_val_loss = avg_val_loss * n_sample_seen + loss.item() * n_in_batch
            n_sample_seen += n_in_batch
            avg_val_loss /= n_sample_seen

        if avg_val_loss < best_val_loss:
            torch.save(model.state_dict(), checkpoint)
            print(f'\nValidation loss improved from {np.around(best_val_loss, 6)} to {np.around(avg_val_loss, 6)}. Model saved as {checkpoint}.')
            best_val_loss = avg_val_loss
        else:
            print(f'\nValidation loss did not improve from {np.around(best_val_loss, 6)}.')

def evaluate_models(model, dataloader, show_indiv = False):
    """Print the majority-vote accuracy of the ensemble on ``dataloader``.

    Args:
        model: ensemble exposing ``n_ensemble`` and returning one output
            tensor per member when called.
        dataloader: iterable of ``(inputs, targets)`` batches.
        show_indiv: when True, also print the accuracy of every member.

    The model is switched to eval mode and left there. Nothing is returned.
    """
    model.eval()
    # acc[0] counts correct majority votes, acc[i] correct votes of member i.
    acc = [0 for _ in range(model.n_ensemble + 1)]
    n = 0
    # Inference only: without no_grad every forward pass would build an
    # autograd graph that the conv layers keep alive through their cached output.
    with torch.no_grad():
        for x_test, y_test in dataloader:
            X, y_true = x_test.to(device), y_test.to(device)
            y_pred_ens = model(X)
            y_pred_ens = [torch.argmax(y_pred, -1) for y_pred in y_pred_ens]
            # Most frequent predicted class across members, per sample.
            y_pred = torch.mode(torch.stack(y_pred_ens, 1)).values
            acc[0] += torch.sum(torch.eq(y_pred, y_true)).item()
            for i in range(1, model.n_ensemble + 1):
                acc[i] += torch.sum(torch.eq(y_pred_ens[i - 1], y_true)).item()
            n += len(y_test)

    print(f'Combined Ensemble Accuracy: {np.around(100 * acc[0] / n, 2)}%')
    if show_indiv:
        for i in range(1, model.n_ensemble + 1):
            print(f'Classifier-{i} Clean Accuracy: {np.around(100 * acc[i] / n, 2)}%')

**PARL Training Specification**

In [8]:
# Experiment settings; the training cell reads all of these names.
n_ensemble = 3 # Number of classifiers in the ensemble
n_layers = 5 # Number of initial conv layers used by the PARL loss penalty
dataset = 'CIFAR-100' # Supported values: 'CIFAR-10', 'CIFAR-100'
gamma = 0.25 # Weight of the penalty term in the PARL loss
epochs = 50 # Number of training epochs

**Train**

In [None]:
# Build the data loaders, the ensemble and the optimisation objects.
n_classes, train_dataloader, test_dataloader = get_dataloader(dataset)
model = PARLModel(n_classes = n_classes, n_ensemble = n_ensemble).to(device)
model.init_weights()
train_loss_fn = PARLLoss(n_layers = n_layers, gamma = gamma)
optimizer = Adam(model.parameters(), lr = 0.001)
# Multiply the learning rate by 0.1 after 50% and again after 80% of the epochs.
scheduler = MultiStepLR(optimizer, [epochs // 2, 4 * epochs // 5], gamma = 0.1)
checkpoint = f'PARL{n_layers}_{dataset}_ResNet20_{n_ensemble}_{epochs}epochs_gamma_{gamma}.pth'

# train
# NOTE(review): the test loader doubles as the validation set, so the
# checkpoint is selected on the same data it is evaluated on below.
train(model, train_dataloader, test_dataloader, epochs, train_loss_fn, scheduler, optimizer, checkpoint)

# evaluate
# map_location lets a checkpoint written on a CUDA machine load on a CPU-only one.
model.load_state_dict(torch.load(checkpoint, map_location = device))
evaluate_models(model, test_dataloader)