In [None]:
%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
import numpy as np
import copy
from torch import nn
from torch.utils.data import DataLoader, Dataset

In [None]:
from torch.utils import data
import matplotlib.pyplot as plt
import torch.optim as optim

In [None]:
from tqdm import tqdm
import copy

# Dataset loading and preprocessing

In [None]:
# The only preprocessing step is the PIL image -> tensor conversion.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# CIFAR-10 training split (fetched into ./data when not already present).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# CIFAR-10 test split, same storage location and transform.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# Names of the coarsest labels; position = value returned by class_to_top_cat.
top_categories = [
    'non_living',
    'living',
]

# Names of the intermediate labels; position = value returned by class_to_mid_cat.
mid_categories = [
    'ground_vehicle',
    'non_ground_vehicle',
    'land_animals',
    'non_land_animals',
]

# Names of the ten fine-grained classes, indexed by the integer class label (0-9).
classes = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

In [None]:
class DatasetSplit(Dataset):
    """Subset view over another dataset, selected by a list of indices.

    ``dataset`` must return ``(image, label)`` pairs when indexed; ``idxs`` is
    any iterable of index-like values (they are converted to plain ``int``).
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(map(int, idxs))

    def __len__(self):
        # The subset is exactly as long as its index list.
        return len(self.idxs)

    def __getitem__(self, item):
        # Translate the position inside the subset into an index of the wrapped dataset.
        source_index = self.idxs[item]
        sample = self.dataset[source_index]
        image, label = sample
        # Return a detached copy so callers never share storage with the source image.
        return image.clone().detach(), label

In [None]:
# Fix the global NumPy and PyTorch random seeds so the notebook is reproducible
# (NumPy's generator is used later by np.random.choice when sampling the subset).
np.random.seed(100)
torch.manual_seed(100)

In [None]:
# Class label of every training image, as an array so it can be compared element-wise
original_trainset_labels = np.array(trainset.targets)

# The subset holds 1/25 of the training set, split evenly over the classes
samples_per_class = len(trainset) // 25 // len(classes)

# Indices (into trainset) of the images selected for the subset
subset_indices = []

# Draw the same number of images from each class, without replacement
for class_idx in range(len(classes)):
    # Positions of all images carrying the current label
    class_indices = np.flatnonzero(original_trainset_labels == class_idx)

    # Random selection among them (uses the global NumPy RNG)
    class_subset_indices = np.random.choice(
        class_indices,
        size=samples_per_class,
        replace=False,
    )

    # Collect the selection
    subset_indices += list(class_subset_indices)

# Wrap the selected indices into a dataset ...
subset_trainset = DatasetSplit(trainset, subset_indices)

# ... and expose it through a shuffling loader
subset_trainloader = DataLoader(
    subset_trainset,
    batch_size=2000,
    shuffle=True,
    num_workers=0,
)

In [None]:
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

# testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
def class_to_top_cat(class_label):
    """Map a class label (0-9) to its top-category index.

    Returns 0 for labels 0, 1, 8, 9 and 1 for labels 2-7. Any other value
    prints a warning and yields -1. The input is converted with ``int`` first.
    """
    label = int(class_label)
    first_group = (0, 1, 8, 9)
    second_group = (2, 3, 4, 5, 6, 7)

    if label in first_group:
        return 0
    if label in second_group:
        return 1

    print("Class label is not valid")
    return -1

def class_to_mid_cat(class_label):
    """Map a class label (0-9) to its mid-category index (0-3).

    Labels 1, 9 -> 0; labels 0, 8 -> 1; labels 3, 4, 5, 7 -> 2; labels 2, 6 -> 3.
    Any other value prints a warning and yields -1. The input is converted
    with ``int`` first.
    """
    label = int(class_label)
    mid_of_class = {
        1: 0, 9: 0,
        0: 1, 8: 1,
        3: 2, 4: 2, 5: 2, 7: 2,
        2: 3, 6: 3,
    }

    if label not in mid_of_class:
        print("Class label is not valid")
        return -1
    return mid_of_class[label]

In [None]:
# Sanity check: print every class next to the top/mid category it is mapped to
for i in range(10):
    top_name = top_categories[class_to_top_cat(i)]
    mid_name = mid_categories[class_to_mid_cat(i)]
    print("a ", classes[i], " is a ", top_name, mid_name)

# MGDA functions

In [None]:
!pip install quadprog

In [None]:
import quadprog

In [None]:
def solve_w(U):
    """Compute the MGDA combination weights for a set of gradients.

    Solves ``min_alpha 0.5 * alpha^T K alpha`` where ``K[i, j]`` is the inner
    product of gradient ``i`` with gradient ``j``, under the constraints passed
    to quadprog below (``sum(alpha) >= 1`` and ``alpha >= 0``).

    Args:
        U: list with one entry per task/user; each entry is a list of tensors
            (one per parameter). All entries must list the same parameters in
            the same order so that tensors can be paired position by position.

    Returns:
        1-D NumPy array with one weight per entry of ``U``.

    Raises:
        ValueError: if quadprog cannot solve the problem (for example because
            the Gram matrix is not positive definite).
    """
    n = len(U)

    # Gram matrix of the gradients. It is symmetric, so only the upper triangle
    # is computed. ``.item()`` turns each partial sum into a Python float, which
    # keeps the accumulation in double precision and also works for tensors
    # that live on a GPU.
    K = np.zeros((n, n), dtype=float)
    for i in range(n):
        for j in range(i, n):
            dot = sum(torch.mul(g_i, g_j).sum().item() for g_i, g_j in zip(U[i], U[j]))
            K[i, j] = K[j, i] = dot

    Q = 0.5 * (K + K.T)
    p = np.zeros(n, dtype=float)
    # quadprog expects constraints as C^T x >= b: the first column of A encodes
    # sum(alpha) >= 1, the identity columns encode alpha_i >= 0.
    a = np.ones(n, dtype=float).reshape(-1, 1)
    Id = np.eye(n, dtype=float)
    A = np.concatenate((a, Id), axis=1)
    b = np.zeros(n + 1)
    b[0] = 1.
    try:
        alpha = quadprog.solve_qp(Q, p, A, b)[0]
    except ValueError as v:
        # Previously execution fell through to `return alpha` with `alpha`
        # unbound, which hid the real cause behind an UnboundLocalError.
        print('MGDA stops since the min norm element is zero')
        raise ValueError(
            "quadprog could not solve the MGDA min-norm problem: {}".format(v)
        ) from v
    return alpha

In [None]:
def solve_padded_w(U):
    """Compute the MGDA combination weights for zero-padded gradients.

    Each entry of ``U`` is the gradient of one loss with respect to the full
    list of shared parameters; positions a loss does not depend on are expected
    to hold zero tensors, so every entry has the same length and layout.
    Solves ``min_alpha 0.5 * alpha^T K alpha`` where ``K[i, j]`` is the inner
    product of entry ``i`` with entry ``j``, under the constraints passed to
    quadprog below (``sum(alpha) >= 1`` and ``alpha >= 0``).

    Args:
        U: list with one entry per loss; each entry is a list of tensors
            (one per shared parameter, same order in every entry).

    Returns:
        1-D NumPy array with one weight per entry of ``U``.

    Raises:
        ValueError: if quadprog cannot solve the problem (for example because
            the Gram matrix is not positive definite).
    """
    n = len(U)

    # Gram matrix of the gradients. It is symmetric, so only the upper triangle
    # is computed. ``.item()`` turns each partial sum into a Python float, which
    # keeps the accumulation in double precision and also works for tensors
    # that live on a GPU.
    K = np.zeros((n, n), dtype=float)
    for i in range(n):
        for j in range(i, n):
            dot = sum(torch.mul(g_i, g_j).sum().item() for g_i, g_j in zip(U[i], U[j]))
            K[i, j] = K[j, i] = dot

    Q = 0.5 * (K + K.T)
    p = np.zeros(n, dtype=float)
    # quadprog expects constraints as C^T x >= b: the first column of A encodes
    # sum(alpha) >= 1, the identity columns encode alpha_i >= 0.
    a = np.ones(n, dtype=float).reshape(-1, 1)
    Id = np.eye(n, dtype=float)
    A = np.concatenate((a, Id), axis=1)
    b = np.zeros(n + 1)
    b[0] = 1.
    try:
        alpha = quadprog.solve_qp(Q, p, A, b)[0]
    except ValueError as v:
        # Previously execution fell through to `return alpha` with `alpha`
        # unbound, which hid the real cause behind an UnboundLocalError.
        print('MGDA stops since the min norm element is zero')
        raise ValueError(
            "quadprog could not solve the MGDA min-norm problem: {}".format(v)
        ) from v
    return alpha


# Model Architecture

## CNN Architecture

In [None]:
# 'FEX' means 'Feature EXtract'

class first_FEX_layer(nn.Module):
    """First feature extractor: two conv -> batch-norm -> ReLU -> max-pool stages.

    Takes 3-channel images and produces 32 feature maps; a 32x32 input ends up
    as a tensor of shape (N, 32, 5, 5).
    """

    def __init__(self):
        super().__init__()
        # The creation order of the sub-modules is kept as is: it determines
        # the state_dict layout and the order in which random initial weights
        # are drawn.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5)
        self.conv1_bn = nn.BatchNorm2d(num_features=32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5)
        self.conv2_bn = nn.BatchNorm2d(num_features=32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        stages = (
            (self.conv1, self.conv1_bn),
            (self.conv2, self.conv2_bn),
        )
        for conv, batch_norm in stages:
            x = self.pool(self.relu(batch_norm(conv(x))))
        return x

class second_FEX_layer(nn.Module):
    """Second feature extractor: flatten to 800 features, then Linear(800, 384) + ReLU."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=384)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Collapse everything into rows of 32 * 5 * 5 values before the dense layer.
        flat = x.view(-1, 32 * 5 * 5)
        hidden = self.fc1(flat)
        return self.relu(hidden)

class third_FEX_layer(nn.Module):
    """Third feature extractor: a single Linear(dim_in, dim_hidden) followed by ReLU."""

    def __init__(self, dim_in=384, dim_hidden=192):
        super().__init__()
        self.layer_input = nn.Linear(in_features=dim_in, out_features=dim_hidden)
        self.relu = nn.ReLU()
        # self.dropout = nn.Dropout()

    def forward(self, x):
        hidden = self.layer_input(x)
        return self.relu(hidden)


class first_classifier(nn.Module):
    """Linear head with 2 outputs, applied to the flattened 32 * 5 * 5 feature maps."""

    def __init__(self):
        super().__init__()
        self.layerout = nn.Linear(in_features=32 * 5 * 5, out_features=2)

    def forward(self, x):
        # Collapse the feature maps into rows of 32 * 5 * 5 values, then score them.
        flat = x.view(-1, 32 * 5 * 5)
        return self.layerout(flat)

class second_classifier(nn.Module):
    """Linear head with 4 outputs on top of 384-dimensional features."""

    def __init__(self):
        super().__init__()
        self.layerout = nn.Linear(in_features=384, out_features=4)

    def forward(self, x):
        return self.layerout(x)

class third_classifier(nn.Module):
    """Linear head with 10 outputs on top of 192-dimensional features."""

    def __init__(self):
        super().__init__()
        self.layerout = nn.Linear(in_features=192, out_features=10)

    def forward(self, x):
        return self.layerout(x)

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Training (RP-MGDA)

In [None]:
# Hyperparameters of the RP-MGDA run.
# Number of passes over subset_trainloader.
epochs = 2005 # Change to 5 for warm start
# Initial step size of the manual SGD updates; the training loop decays it by a factor 0.8 periodically.
lr = 0.008
# Criterion used for all three heads (top, mid and 10-class labels).
loss = nn.CrossEntropyLoss()

In [None]:
# RP-MGDA training of the three-level "ladder" network.
#
# Forward path:  X -> ftxt_layer1 -> h1 -> ftxt_layer2 -> h2 -> ftxt_layer3 -> h3
# Heads:         classifier1(h1) -> top label, classifier2(h2) -> mid label,
#                classifier3(h3) -> 10-class label.
# Each head (and, in RP-MGDA, ftxt_layer3) is updated right away with plain SGD
# from its own loss. ftxt_layer1 and ftxt_layer2 are shared by several losses,
# so their per-loss gradients are combined with the weights from solve_padded_w.
torch.manual_seed(100) # 42,100

# 'ftxt' means feature extract

ftxt_layer1 = first_FEX_layer().to(device)
ftxt_layer2 = second_FEX_layer().to(device)
ftxt_layer3 = third_FEX_layer().to(device)
classifier1 = first_classifier().to(device)
classifier2 = second_classifier().to(device)
classifier3 = third_classifier().to(device)

ftxt_layer1.train()
ftxt_layer2.train()
ftxt_layer3.train()
classifier1.train()
classifier2.train()
classifier3.train()

# One entry per epoch; each entry is the list of losses recorded in that epoch
loss_list = []

# Decay a local copy of the learning rate, so that re-running this cell starts
# again from the configured `lr` instead of an already decayed value.
current_lr = lr

for epoch in tqdm(range(epochs)):
    # Step decay: multiply the learning rate by 0.8 every 200 epochs
    if epoch % 200 == 199:
        current_lr *= 0.8

    # Re-init the losses recorded for this epoch
    losses = []

    # Take a batch
    for batch_idx, (X, y) in enumerate(subset_trainloader):
        # Tensor.apply_ modifies its tensor IN PLACE, so the coarse labels must
        # be derived from copies of y. Mapping y itself would overwrite the
        # 10-class labels and compute the mid labels from already-mapped values.
        top_label = y.clone().apply_(class_to_top_cat).to(device)
        mid_label = y.clone().apply_(class_to_mid_cat).to(device)
        X = X.to(device)
        y = y.to(device)

        h1 = ftxt_layer1(X)
        h2 = ftxt_layer2(h1)
        h3 = ftxt_layer3(h2)
        output1 = classifier1(h1)
        output2 = classifier2(h2)
        output3 = classifier3(h3)
        l1 = loss(output1, top_label)
        l2 = loss(output2, mid_label)
        l3 = loss(output3, y)

        # 1st pass, top categories (the graph is kept because l2 and l3 still need it)
        l1.backward(retain_graph=True)
        with torch.no_grad():
            # Store ftxt_layer1 gradient for later MGDA
            ft1_gradient_l1 = [p.grad.clone() for p in ftxt_layer1.parameters()]

            # Update classifier1
            for p in classifier1.parameters():
                p -= p.grad * current_lr
            classifier1.zero_grad()
            ftxt_layer1.zero_grad()

            losses.append(l1.item())

        # 2nd pass, mid categories
        l2.backward(retain_graph=True)
        with torch.no_grad():
            # Store ftxt_layer1 and ftxt_layer2 gradient for later MGDA
            ft1_gradient_l2 = [p.grad.clone() for p in ftxt_layer1.parameters()]
            ft2_gradient_l2 = [p.grad.clone() for p in ftxt_layer2.parameters()]

            # Update classifier2
            for p in classifier2.parameters():
                p -= p.grad * current_lr
            classifier2.zero_grad()
            ftxt_layer1.zero_grad()
            ftxt_layer2.zero_grad()

            losses.append(l2.item())

        # 3rd pass, 10 labels (last backward pass, the graph can be freed)
        l3.backward(retain_graph=False)
        with torch.no_grad():
            # Store ftxt_layer1, ftxt_layer2 gradient for later MGDA
            # ftxt_layer3 (in RP-MGDA only) and classifier3 can be updated immediately
            ft1_gradient_l3 = [p.grad.clone() for p in ftxt_layer1.parameters()]
            ft2_gradient_l3 = [p.grad.clone() for p in ftxt_layer2.parameters()]

            # Update ftxt_layer3
            for p in ftxt_layer3.parameters():
                p -= p.grad * current_lr

            # Update classifier3
            for p in classifier3.parameters():
                p -= p.grad * current_lr
            ftxt_layer1.zero_grad()
            ftxt_layer2.zero_grad()
            ftxt_layer3.zero_grad()
            classifier3.zero_grad()

            losses.append(l3.item())

        # MGDA on ft1_gradient + ft2_gradient
        # l1 does not depend on ftxt_layer2, so its gradient is padded with
        # zeros shaped like the ftxt_layer2 gradients.
        zeros_like_ft2 = [torch.zeros_like(p) for p in ft2_gradient_l2]
        grad1 = ft1_gradient_l1 + zeros_like_ft2
        grad2 = ft1_gradient_l2 + ft2_gradient_l2
        grad3 = ft1_gradient_l3 + ft2_gradient_l3
        grads = [grad1, grad2, grad3]
        ft1_grads = [ft1_gradient_l1, ft1_gradient_l2, ft1_gradient_l3]
        ft2_grads = [zeros_like_ft2, ft2_gradient_l2, ft2_gradient_l3]

        with torch.no_grad():
            gradient_coefficients = solve_padded_w(grads)
            print("gradient coefficients for epoch ",epoch, "is ", gradient_coefficients)

            # Update ftxt_layer1 and ftxt_layer2 with the weighted sum of the
            # per-loss gradients. The parameters are changed in place: going
            # through a state_dict copied before the forward pass would also
            # write back the old BatchNorm running statistics and thereby
            # discard the statistics gathered during this forward pass.
            for layer, per_loss_grads in ((ftxt_layer1, ft1_grads), (ftxt_layer2, ft2_grads)):
                for i, param in enumerate(layer.parameters()):
                    for coefficient, loss_grads in zip(gradient_coefficients, per_loss_grads):
                        param -= current_lr * float(coefficient) * loss_grads[i]

    print(losses)

    loss_list.append(losses)






In [None]:
# Losses recorded during the last completed epoch (l1, l2, l3 for every batch)
print(losses)

In [None]:
# Clear the gradients stored on all six sub-networks
for net in (
    ftxt_layer1,
    ftxt_layer2,
    ftxt_layer3,
    classifier1,
    classifier2,
    classifier3,
):
    net.zero_grad()

In [None]:
import pickle

In [None]:
# Persist the per-epoch loss history of the RP-MGDA run.
# NOTE(review): the file name is fixed, so running this cell after a short
# warm-start run overwrites the results of an earlier full run.
with open('seed100_RPMGDA_LADDER_losses_EP2000_lr0p008_warmstart.pickle','wb') as f:
    pickle.dump(loss_list,f)

# Training (MGDA)

## Instruction:
(1) Go to the Training (RP-MGDA) section, set `epochs = 5` for the warm start, and run it.\
(2) After that, run the cells below for MGDA training.

In [None]:
# Hyperparameters of the MGDA run (started from the warm-started networks).
# Number of passes over subset_trainloader.
epochs = 2000
# Initial step size of the manual SGD updates; the training loop decays it by a factor 0.8 periodically.
lr = 0.008
# Criterion used for all three heads (top, mid and 10-class labels).
loss = nn.CrossEntropyLoss()

In [None]:
# MGDA training of the three-level "ladder" network, continuing from the
# networks created (and warm-started) in the RP-MGDA section.
#
# Unlike RP-MGDA, ftxt_layer3 is not updated directly from l3: all three
# feature extractors are updated with the weighted combination of the per-loss
# gradients returned by solve_padded_w. Only the heads get immediate SGD steps.

# One entry per epoch; each entry is the list of losses recorded in that epoch
loss_list = []

# Decay a local copy of the learning rate, so that re-running this cell starts
# again from the configured `lr` instead of an already decayed value.
current_lr = lr

for epoch in tqdm(range(epochs)):
    # Step decay: multiply the learning rate by 0.8 every 200 epochs
    if epoch % 200 == 194:
        current_lr *= 0.8

    # Re-init the losses recorded for this epoch
    losses = []

    # Take a batch
    for batch_idx, (X, y) in enumerate(subset_trainloader):
        # Tensor.apply_ modifies its tensor IN PLACE, so the coarse labels must
        # be derived from copies of y. Mapping y itself would overwrite the
        # 10-class labels and compute the mid labels from already-mapped values.
        top_label = y.clone().apply_(class_to_top_cat).to(device)
        mid_label = y.clone().apply_(class_to_mid_cat).to(device)
        X = X.to(device)
        y = y.to(device)

        h1 = ftxt_layer1(X)
        h2 = ftxt_layer2(h1)
        h3 = ftxt_layer3(h2)
        output1 = classifier1(h1)
        output2 = classifier2(h2)
        output3 = classifier3(h3)
        l1 = loss(output1, top_label)
        l2 = loss(output2, mid_label)
        l3 = loss(output3, y)

        # 1st pass, top categories (the graph is kept because l2 and l3 still need it)
        l1.backward(retain_graph=True)
        with torch.no_grad():
            # Store ftxt_layer1 gradient for later MGDA
            ft1_gradient_l1 = [p.grad.clone() for p in ftxt_layer1.parameters()]

            # Update classifier1
            for p in classifier1.parameters():
                p -= p.grad * current_lr
            classifier1.zero_grad()
            ftxt_layer1.zero_grad()

            losses.append(l1.item())

        # 2nd pass, mid categories
        l2.backward(retain_graph=True)
        with torch.no_grad():
            # Store ftxt_layer1 and ftxt_layer2 gradient for later MGDA
            ft1_gradient_l2 = [p.grad.clone() for p in ftxt_layer1.parameters()]
            ft2_gradient_l2 = [p.grad.clone() for p in ftxt_layer2.parameters()]

            # Update classifier2
            for p in classifier2.parameters():
                p -= p.grad * current_lr
            classifier2.zero_grad()
            ftxt_layer1.zero_grad()
            ftxt_layer2.zero_grad()

            losses.append(l2.item())

        # 3rd pass, label (last backward pass, the graph can be freed)
        l3.backward(retain_graph=False)
        with torch.no_grad():
            # Store ftxt_layer1, ftxt_layer2, ftxt_layer3 gradient for later MGDA
            # Only classifier3 can be updated immediately (this is different compared to RPMGDA)
            ft1_gradient_l3 = [p.grad.clone() for p in ftxt_layer1.parameters()]
            ft2_gradient_l3 = [p.grad.clone() for p in ftxt_layer2.parameters()]
            ft3_gradient_l3 = [p.grad.clone() for p in ftxt_layer3.parameters()]

            # Update classifier3
            for p in classifier3.parameters():
                p -= p.grad * current_lr

            ftxt_layer1.zero_grad()
            ftxt_layer2.zero_grad()
            ftxt_layer3.zero_grad()
            classifier3.zero_grad()

            losses.append(l3.item())

        # MGDA on ft1_gradient + ft2_gradient + ft3_gradient
        # l1 does not depend on ftxt_layer2 / ftxt_layer3 and l2 does not depend
        # on ftxt_layer3, so those positions are padded with zeros of the
        # matching shapes.
        zeros_like_ft2 = [torch.zeros_like(p) for p in ft2_gradient_l2]
        zeros_like_ft3 = [torch.zeros_like(p) for p in ft3_gradient_l3]
        grad1 = ft1_gradient_l1 + zeros_like_ft2 + zeros_like_ft3
        grad2 = ft1_gradient_l2 + ft2_gradient_l2 + zeros_like_ft3
        grad3 = ft1_gradient_l3 + ft2_gradient_l3 + ft3_gradient_l3
        grads = [grad1, grad2, grad3]
        ft1_grads = [ft1_gradient_l1, ft1_gradient_l2, ft1_gradient_l3]
        ft2_grads = [zeros_like_ft2, ft2_gradient_l2, ft2_gradient_l3]
        ft3_grads = [zeros_like_ft3, zeros_like_ft3, ft3_gradient_l3]

        with torch.no_grad():
            gradient_coefficients = solve_padded_w(grads)
            print("gradient coefficients for epoch ",epoch, "is ", gradient_coefficients)

            # Update the three feature extractors with the weighted sum of the
            # per-loss gradients. The parameters are changed in place: going
            # through a state_dict copied before the forward pass would also
            # write back the old BatchNorm running statistics and thereby
            # discard the statistics gathered during this forward pass.
            shared_layers = (
                (ftxt_layer1, ft1_grads),
                (ftxt_layer2, ft2_grads),
                (ftxt_layer3, ft3_grads),
            )
            for layer, per_loss_grads in shared_layers:
                for i, param in enumerate(layer.parameters()):
                    for coefficient, loss_grads in zip(gradient_coefficients, per_loss_grads):
                        param -= current_lr * float(coefficient) * loss_grads[i]

    print(losses)

    loss_list.append(losses)






In [None]:
# Persist the per-epoch loss history of the MGDA run (overwrites an existing file of the same name).
with open('seed100_MGDA_LADDER_losses_EP2000_lr0p008_warmstart.pickle','wb') as f:
    pickle.dump(loss_list,f)