In [None]:
%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
import numpy as np
import copy
from torch import nn
from torch.utils.data import DataLoader, Dataset

In [None]:
from torch.utils import data
import matplotlib.pyplot as plt
import torch.optim as optim

In [None]:
from six.moves import urllib
# Install a process-wide urllib opener that sends a browser-like User-agent;
# presumably some dataset mirrors reject urllib's default agent string.
browser_headers = [('User-agent', 'Mozilla/5.0')]
opener = urllib.request.build_opener()
opener.addheaders = list(browser_headers)
urllib.request.install_opener(opener)

In [None]:
from tqdm import tqdm

In [None]:
import pickle

# Dataset loading and preprocessing


In [None]:
# Images are only converted to tensors (no normalisation / augmentation).
transform = transforms.Compose([transforms.ToTensor()])

# Both CIFAR-10 splits share the root directory, download flag and transform.
cifar10_kwargs = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.CIFAR10(train=True, **cifar10_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **cifar10_kwargs)

In [None]:
# Seed NumPy (used by noniid's shard sampling) and torch (weights, shuffling).
SEED = 100
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Loaders over the full train/test splits (the federated loops below use the
# per-client loaders instead).
loader_kwargs = dict(batch_size=256, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

In [None]:
num_train_samples = len(trainset)
print(num_train_samples)

In [None]:
def noniid(dataset, num_users=100, num_shards=300, num_imgs=200):
    """Split a labelled dataset into non-I.I.D. per-client index sets.

    The first ``num_shards * num_imgs`` samples are sorted by label and cut into
    ``num_shards`` contiguous shards of ``num_imgs`` samples. Each client gets
    ``num_shards // num_users`` distinct shards, drawn without replacement with
    NumPy's global RNG (seed it beforehand for reproducibility).

    :param dataset: object with a ``targets`` sequence of labels
        (e.g. a torchvision CIFAR10 / MNIST dataset).
    :param num_users: number of clients.
    :param num_shards: total number of shards.
    :param num_imgs: samples per shard.
    :return: dict mapping client id (0..num_users-1) to a 1-D integer numpy
        array of sample indices.
    :raises ValueError: if the dataset has fewer than ``num_shards * num_imgs``
        samples.
    """
    num_samples = num_shards * num_imgs
    # Only the first num_samples samples are sharded; extra samples are ignored
    # instead of making np.vstack fail on mismatched lengths.
    labels = np.asarray(dataset.targets)[:num_samples]
    if len(labels) < num_samples:
        raise ValueError(
            f"dataset has {len(labels)} samples but num_shards * num_imgs = {num_samples}")

    # Sort sample indices by label so each contiguous shard is label-homogeneous.
    idxs = np.arange(num_samples)
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    shards_per_user = num_shards // num_users
    idx_shard = list(range(num_shards))
    dict_users = {}
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, shards_per_user, replace=False))
        # Remove the drawn shards so no two clients share one.
        idx_shard = list(set(idx_shard) - rand_set)
        pieces = [idxs[rand * num_imgs:(rand + 1) * num_imgs] for rand in rand_set]
        # Leading empty int array keeps the dtype integral and handles zero shards.
        dict_users[i] = np.concatenate([np.array([], dtype=idxs.dtype)] + pieces)
    return dict_users

In [None]:
# 250 label-sorted shards of 200 images -> 25 users x 10 shards = 2000 samples each.
user_groups = noniid(trainset, num_users=25, num_shards=250, num_imgs=200)

In [None]:
first_user_size = len(user_groups[0])
print(first_user_size)

## Use the first 4 of the 25 users (non-IID data)


In [None]:
# Re-seed so the experiments below do not depend on how many RNG draws
# noniid() consumed above.
SEED = 100
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Convert the first four users' index arrays into plain lists of Python ints.
user1_idxs, user2_idxs, user3_idxs, user4_idxs = (
    [int(i) for i in user_groups[user_id]] for user_id in range(4)
)

In [None]:
num_user1_samples = len(user1_idxs)
print(num_user1_samples)

In [None]:
# Show which classes each of the four clients holds.
train_targets = np.array(trainset.targets)
for client_idxs in (user1_idxs, user2_idxs, user3_idxs, user4_idxs):
    print(np.unique(train_targets[client_idxs]))

In [None]:
class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``k`` of the split is item ``idxs[k]`` of the wrapped dataset; the
    image is returned as a cloned, detached tensor together with its label.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image.clone().detach(), label

In [None]:
# One shuffled loader (batch size 200) per client's split of the training set.
trainloader_1, trainloader_2, trainloader_3, trainloader_4 = (
    DataLoader(DatasetSplit(trainset, client_idxs), batch_size=200, shuffle=True)
    for client_idxs in (user1_idxs, user2_idxs, user3_idxs, user4_idxs)
)

In [None]:
# Entry k is client k's loader; the training loops pair it with models[k].
trainloaders = [
    trainloader_1,
    trainloader_2,
    trainloader_3,
    trainloader_4,
]

# Network architecture

## MLP architecture

In [None]:
class low_MLP(nn.Module):
    """Client-side MLP: flatten, then Linear -> ReLU -> Linear -> ReLU.

    Args:
        dim_in: size of the flattened input (C*H*W for a (B, C, H, W) batch).
        dim_hidden: width of the hidden layer.
        dim_out: size of the returned (non-negative) representation.
    """

    def __init__(self, dim_in, dim_hidden=256, dim_out=128):
        super().__init__()
        # Layers are created in the same order as before so seeded
        # initialisation draws identical weights.
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        flat_dim = x.shape[1] * x.shape[-2] * x.shape[-1]
        hidden = self.relu(self.layer_input(x.view(-1, flat_dim)))
        return self.relu(self.layer_hidden(hidden))

In [None]:
class top_MLP(nn.Module):
    """Shared head for MLP experiments: Linear -> ReLU -> Linear, returning logits.

    ``self.softmax`` is constructed but not applied in ``forward``; the loss used
    in training (``nn.CrossEntropyLoss``) takes unnormalised logits.
    """

    def __init__(self, dim_in=128, dim_hidden=128, dim_out=10):
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        hidden = self.relu(self.layer_input(x))
        return self.layer_hidden(hidden)

## CNN architecture

In [None]:
class low_CNN(nn.Module):
    """Client-side feature extractor for (B, 3, 32, 32) image batches.

    Two conv(5x5) -> BatchNorm -> ReLU -> 2x2 max-pool stages with 32 channels
    reduce the input to (B, 32, 5, 5), which is flattened and mapped through a
    Linear layer + ReLU to 384 features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 32, 5)
        self.conv2_bn = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 5 * 5, 384)
        self.relu = nn.ReLU()

    def _stage(self, x, conv, bn):
        # conv -> batch-norm -> ReLU -> 2x2 max-pool
        return self.pool(self.relu(bn(conv(x))))

    def forward(self, x):
        assert x.shape[1:] == (3, 32, 32), f"Unexpected input shape: {x.shape}"
        features = self._stage(x, self.conv1, self.conv1_bn)
        features = self._stage(features, self.conv2, self.conv2_bn)
        assert features.shape[1:] == (32, 5, 5), f"Unexpected shape before view: {features.shape}"
        return self.relu(self.fc1(features.view(-1, 32 * 5 * 5)))

In [None]:
class top_CNN(nn.Module):
    """Shared head for the CNN experiments: Linear -> ReLU -> Linear, returning logits.

    Not convolutional itself; it consumes the 384 features produced by
    ``low_CNN``. ``self.softmax`` is constructed but not applied in ``forward``.
    """

    def __init__(self, dim_in=384, dim_hidden=192, dim_out=10):
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        hidden = self.relu(self.layer_input(x))
        return self.layer_hidden(hidden)

In [None]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Number of scalar values in one training image (product of its tensor shape).
img_size = trainset[0][0].numel()

print(img_size)

# MGDA functions

In [None]:
!pip install quadprog

In [None]:
import quadprog

In [None]:
def solve_w(U):
    """Return MGDA min-norm convex-combination weights for a set of gradients.

    Solves the quadratic program

        min_a  0.5 * a^T K a   s.t.  sum(a) = 1,  a >= 0,

    where K[i, j] = <U[i], U[j]> is the Gram matrix of the gradients, so that
    sum_i a[i] * U[i] is the minimum-norm point of their convex hull.

    :param U: list of n gradients; each gradient is a list of tensors (one per
        parameter), and all gradients share the same tensor shapes and device.
    :return: numpy float64 array of n weights (non-negative, summing to 1).
    :raises ValueError: if U is empty, or if quadprog cannot solve the QP
        (K not positive definite, e.g. linearly dependent gradients, in which
        case the min-norm element may be zero).
    """
    n = len(U)
    if n == 0:
        raise ValueError('solve_w needs at least one gradient')

    # Gram matrix accumulated parameter by parameter in float64: one matmul per
    # parameter instead of n*n*len(U[0]) separate reductions, without
    # materialising a full flattened copy of every gradient at once.
    K = torch.zeros((n, n), dtype=torch.float64)
    for t in range(len(U[0])):
        block = torch.stack([u[t].reshape(-1) for u in U]).double()
        K += (block @ block.T).cpu()
    K = K.numpy()

    Q = 0.5 * (K + K.T)  # symmetrise against round-off
    p = np.zeros(n, dtype=float)
    # Columns of A are constraints A[:, c] . a >= b[c]; with meq=1 the first
    # column (sum(a) = 1) is an equality and the remaining n enforce a >= 0.
    A = np.concatenate((np.ones((n, 1), dtype=float), np.eye(n, dtype=float)), axis=1)
    b = np.zeros(n + 1)
    b[0] = 1.
    try:
        alpha = quadprog.solve_qp(Q, p, A, b, meq=1)[0]
    except ValueError as err:
        # Re-raise: merely printing left `alpha` unbound (UnboundLocalError).
        raise ValueError(
            f'MGDA stops: quadprog could not solve the min-norm QP '
            f'(the min norm element may be zero): {err}') from err
    return alpha


# RP-MGDA on PFL

## Setting

In [None]:
# RP-MGDA settings. 50 epochs is the warm-start length expected by the MGDA
# section below (see its Setting instructions); increase it for a standalone
# RP-MGDA run.
epochs = 50
lr = 0.005  # SGD step for the lower models and the MGDA step for top_net
loss = nn.CrossEntropyLoss()

## Train (RP-MGDA)

In [None]:
torch.manual_seed(100) # 42,100

# Creation order (head first, then clients 1-4) fixes the seeded initialisation.
top_net = top_CNN().to(device)
model1 = low_CNN().to(device)
model2 = low_CNN().to(device)
model3 = low_CNN().to(device)
model4 = low_CNN().to(device)

models = [model1, model2, model3, model4]

for net in (top_net, *models):
    net.train()
loss_list = []

# Per-run copy of the step size: decaying it must not leak into the global
# `lr` (the old `lr=lr` was a no-op, so re-running this cell started decayed).
step_lr = lr

progress = tqdm(range(epochs))
for epoch in progress:
    if epoch % 200 == 199:
        step_lr *= 0.9

    grads = []   # one top_net gradient (list of tensors) per mini-batch, all users
    losses = []

    for i, low_model in enumerate(models):
        for X, y in trainloaders[i]:
            X, y = X.to(device), y.to(device)
            l = loss(top_net(low_model(X)), y)
            l.backward()

            with torch.no_grad():
                # Keep the top_net gradient for MGDA; top_net itself is only
                # stepped once per epoch, below.
                grads.append([p.grad.clone() for p in top_net.parameters()])

                # Plain SGD step on this user's lower model.
                for p in low_model.parameters():
                    p -= p.grad * step_lr
                low_model.zero_grad()
                top_net.zero_grad()

            losses.append(l.item())

    gradient_coefficients = solve_w(grads)

    # Step top_net along the MGDA-combined direction. Updating parameters
    # directly (rather than zipping state_dict keys with parameters()) stays
    # aligned even if top_net ever gains buffers.
    with torch.no_grad():
        for k, param in enumerate(top_net.parameters()):
            for coeff, grad in zip(gradient_coefficients, grads):
                param -= step_lr * coeff * grad[k]

    # Compact progress instead of printing every batch loss each epoch;
    # the full per-batch losses are kept in loss_list.
    progress.set_postfix(mean_loss=float(np.mean(losses)))
    loss_list.append(losses)

In [None]:
last_epoch_losses = loss_list[-1]
print(last_epoch_losses)

In [None]:
# Clear leftover gradients on the shared head and every client model.
for net in (top_net, model1, model2, model3, model4):
    net.zero_grad()

## Save RPMGDA results

In [None]:
import pickle

In [None]:
# Encode the configured epoch count in the filename so a 50-epoch warm start
# is not saved under the label of a 5000-epoch run.
rpmgda_results_path = f'CIFAR10_seed100_RPMGDA_EP{epochs}_lr0p005.pickle'
with open(rpmgda_results_path, 'wb') as f:
    pickle.dump(loss_list, f)

# MGDA on PFL


## MGDA implementation

In [None]:
def augment_gradient(top_grad, low_grad, idx, num_users=4):
    """Embed one user's gradients into the joint (top net + all lower nets) space.

    The result is ``top_grad`` followed by ``num_users`` slots of lower-net
    gradients: slot ``idx`` holds ``low_grad`` and every other slot holds zeros
    with the same shapes (all lower nets are assumed to share one architecture).

    :param top_grad: list of top_net gradient tensors, in ``parameters()`` order.
    :param low_grad: list of this user's lower-net gradient tensors, in
        ``parameters()`` order.
    :param idx: index of the user, as used in ``trainloaders[idx]``.
    :param num_users: number of lower-net slots (defaults to the 4 users here).
    :return: flat list ``top_grad + slot_0 + ... + slot_{num_users-1}``, used for
        the MGDA aggregation.
    :raises ValueError: if ``idx`` is not in ``range(num_users)``.
    """
    if not 0 <= idx < num_users:
        # Previously this printed and then hit an UnboundLocalError.
        raise ValueError(f"Error in augmenting gradient, index is not valid: {idx}")

    # The zero padding is shared across slots; it is only read, never modified.
    pad_zeros_grad = [torch.zeros_like(g) for g in low_grad]

    augmented_grad = list(top_grad)
    for slot in range(num_users):
        augmented_grad.extend(low_grad if slot == idx else pad_zeros_grad)
    return augmented_grad

In [None]:
def solve_padded_w(U):
    """Return MGDA min-norm weights for a set of augmented (zero-padded) gradients.

    Solves  min_a 0.5 * a^T K a  s.t.  sum(a) = 1, a >= 0,  where
    K[i, j] = <U[i], U[j]>. For augmented gradients, the zero-padded lower-net
    slots mean two users' gradients only interact through the shared top net
    entries.

    :param U: list of n gradients; each gradient is a list of tensors (one per
        parameter), and all gradients share the same tensor shapes and device.
    :return: numpy float64 array of n weights (non-negative, summing to 1).
    :raises ValueError: if U is empty, or if quadprog cannot solve the QP
        (K not positive definite, e.g. linearly dependent gradients, in which
        case the min-norm element may be zero).
    """
    n = len(U)
    if n == 0:
        raise ValueError('solve_padded_w needs at least one gradient')

    # Gram matrix accumulated parameter by parameter in float64: one matmul per
    # parameter instead of n*n*len(U[0]) separate reductions, without
    # materialising a full flattened copy of every (large) augmented gradient.
    K = torch.zeros((n, n), dtype=torch.float64)
    for t in range(len(U[0])):
        block = torch.stack([u[t].reshape(-1) for u in U]).double()
        K += (block @ block.T).cpu()
    K = K.numpy()

    Q = 0.5 * (K + K.T)  # symmetrise against round-off
    p = np.zeros(n, dtype=float)
    # Columns of A are constraints A[:, c] . a >= b[c]; with meq=1 the first
    # column (sum(a) = 1) is an equality and the remaining n enforce a >= 0.
    A = np.concatenate((np.ones((n, 1), dtype=float), np.eye(n, dtype=float)), axis=1)
    b = np.zeros(n + 1)
    b[0] = 1.
    try:
        alpha = quadprog.solve_qp(Q, p, A, b, meq=1)[0]
    except ValueError as err:
        # Re-raise: merely printing left `alpha` unbound (UnboundLocalError).
        raise ValueError(
            f'MGDA stops: quadprog could not solve the min-norm QP '
            f'(the min norm element may be zero): {err}') from err
    return alpha

## Setting
Instructions: \
(1) In the RP-MGDA *Setting* cell above, make sure `epochs = 50`. \
(2) Run *Train (RP-MGDA)*; those 50 epochs warm-start the network weights. \
(3) Then run the cells below to continue training with MGDA.

In [None]:
# MGDA settings (resets the step size after the RP-MGDA warm start).
epochs, lr = 5000, 0.005
loss = nn.CrossEntropyLoss()

## Train (MGDA)

In [None]:
loss_list = []

# Joint parameter list in the same order as augment_gradient's layout:
# top_net parameters, then models[0], models[1], ... parameters.
joint_params = list(top_net.parameters())
for low_model in models:
    joint_params.extend(low_model.parameters())

progress = tqdm(range(epochs))
for epoch in progress:
    # NOTE(review): RP-MGDA decays at epoch % 200 == 199; confirm 149 is intended.
    if epoch % 200 == 149:
        lr *= 0.9

    grads = []   # one augmented gradient per mini-batch, across all users
    losses = []

    for i, low_model in enumerate(models):
        for X, y in trainloaders[i]:
            X, y = X.to(device), y.to(device)
            l = loss(top_net(low_model(X)), y)
            l.backward()

            with torch.no_grad():
                topnet_gradient = [p.grad.clone() for p in top_net.parameters()]
                lowernet_gradient = [p.grad.clone() for p in low_model.parameters()]
                grads.append(augment_gradient(topnet_gradient, lowernet_gradient, i))

            # Clear grads so each stored gradient belongs to this batch only;
            # without this, .grad accumulates across batches and users.
            top_net.zero_grad()
            low_model.zero_grad()

            losses.append(l.item())

    # MGDA: one coefficient per mini-batch (augmented) gradient.
    gradient_coefficients = solve_padded_w(grads)

    # Step every parameter (shared head and all lower nets) along the combined
    # direction sum_j alpha_j * grads[j]. A lower net only receives its own
    # user's batches because the other users' slots are zero.
    with torch.no_grad():
        for k, param in enumerate(joint_params):
            for coeff, grad in zip(gradient_coefficients, grads):
                param -= lr * coeff * grad[k]

    # Compact progress instead of printing every batch loss each epoch;
    # the full per-batch losses are kept in loss_list.
    progress.set_postfix(mean_loss=float(np.mean(losses)))
    loss_list.append(losses)



## Save MGDA results

In [None]:
mgda_results_path = 'CIFAR10_seed100_MGDA_EP5000_lr0p005.pickle'
with open(mgda_results_path, 'wb') as results_file:
    pickle.dump(loss_list, results_file)