In [97]:
import json
import os
import pickle
import random
import shutil
from argparse import ArgumentParser, Namespace
from collections import OrderedDict, Counter
from copy import deepcopy
from typing import Iterator, Tuple, Union, Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

In [98]:
!pip install fedlab~=1.1.4
from fedlab.utils.serialization import SerializationTool
from fedlab.utils.aggregator import Aggregators
from fedlab.utils.dataset.slicing import noniid_slicing

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [99]:
!pip install rich
import rich
from rich.console import Console
from rich.progress import track

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [100]:
!pip install path
from path import Path

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### MOUNT

In [101]:
from google.colab import drive
drive.mount("/content/drive")
%cd '/content/drive/Shareddrives/Duong-DatDeakin/Personalized_FedAvg'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/Shareddrives/Duong-DatDeakin/Personalized_FedAvg


In [102]:
!pwd

/content/drive/Shareddrives/Duong-DatDeakin/Personalized_FedAvg


### DATA/DATASET

In [103]:
class MNISTDataset(Dataset):
    """In-memory dataset built either from raw tensors or from (sample, label) pairs.

    Two input forms are accepted:

    * ``data`` and ``targets`` together: ``data`` gets a new axis inserted at
      position 1 (``unsqueeze(1)``) and ``targets`` is stored unchanged.
    * ``subset``: a collection of pairs; element ``[0]`` of every pair is stacked
      into ``self.data`` and element ``[1]`` into ``self.targets``. Entries that
      are not already ``torch.Tensor`` are wrapped with ``torch.tensor``.

    ``transform`` and ``target_transform`` are applied on access in ``__getitem__``.

    Raises:
        ValueError: when neither input form is provided.
    """

    def __init__(
        self,
        subset=None,
        data=None,
        targets=None,
        transform=None,
        target_transform=None,
    ) -> None:
        self.transform = transform
        self.target_transform = target_transform

        has_raw_tensors = data is not None and targets is not None
        if not has_raw_tensors and subset is None:
            raise ValueError(
                "Data Format: subset: Tuple(data: Tensor / Image / np.ndarray, targets: Tensor) OR data: List[Tensor]  targets: List[Tensor]"
            )

        if has_raw_tensors:
            self.data = data.unsqueeze(1)
            self.targets = targets
            return

        # `subset` is walked twice (once per field), so it must be re-iterable.
        self.data = torch.stack(
            [
                pair[0] if isinstance(pair[0], torch.Tensor) else torch.tensor(pair[0])
                for pair in subset
            ]
        )
        self.targets = torch.stack(
            [
                pair[1] if isinstance(pair[1], torch.Tensor) else torch.tensor(pair[1])
                for pair in subset
            ]
        )

    def __getitem__(self, index):
        """Return the (optionally transformed) sample and target at ``index``."""
        sample = self.data[index]
        label = self.targets[index]

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label

    def __len__(self):
        """Number of stored targets."""
        return len(self.targets)


class CIFARDataset(Dataset):
    """In-memory image dataset built either from raw tensors or from (image, label) pairs.

    Two input forms are accepted:

    * ``data`` and ``targets`` together: ``data`` gets a new axis inserted at
      position 1 (``unsqueeze(1)``) and ``targets`` is stored unchanged.
    * ``subset``: a collection of pairs; element ``[0]`` of every pair is stacked
      into ``self.data`` and element ``[1]`` into ``self.targets``. Entries that
      are not already ``torch.Tensor`` are wrapped with ``torch.tensor``.

    ``transform`` and ``target_transform`` are applied on access in ``__getitem__``.

    Raises:
        ValueError: when neither input form is provided.
    """

    def __init__(
        self,
        subset=None,
        data=None,
        targets=None,
        transform=None,
        target_transform=None,
    ) -> None:
        self.transform = transform
        self.target_transform = target_transform

        has_raw_tensors = data is not None and targets is not None
        if not has_raw_tensors and subset is None:
            raise ValueError(
                "Data Format: subset: Tuple(data: Tensor / Image / np.ndarray, targets: Tensor) OR data: List[Tensor]  targets: List[Tensor]"
            )

        if has_raw_tensors:
            # NOTE(review): unsqueeze(1) mirrors MNISTDataset; for image batches that
            # already carry a channel axis this adds an extra one - confirm the
            # intended input layout before relying on this branch.
            self.data = data.unsqueeze(1)
            self.targets = targets
            return

        # `subset` is walked twice (once per field), so it must be re-iterable.
        self.data = torch.stack(
            [
                pair[0] if isinstance(pair[0], torch.Tensor) else torch.tensor(pair[0])
                for pair in subset
            ]
        )
        self.targets = torch.stack(
            [
                pair[1] if isinstance(pair[1], torch.Tensor) else torch.tensor(pair[1])
                for pair in subset
            ]
        )

    def __getitem__(self, index):
        """Return the (optionally transformed) image and target at ``index``."""
        image = self.data[index]
        label = self.targets[index]

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        """Number of stored targets."""
        return len(self.targets)


### PREPROCESS

In [104]:
# Root folder under which datasets, per-client pickles and stats are written.
# Hard-coded to the mounted Drive folder; the script-style variant is kept below
# (presumably disabled because `__file__` is unavailable in a notebook).
#CURRENT_DIR = Path(__file__).parent.abspath()
CURRENT_DIR = "/content/drive/Shareddrives/Duong-DatDeakin/Personalized_FedAvg"

# Dataset key -> (torchvision dataset class used for download/loading,
#                 in-memory wrapper class used for the per-client pickles).
DATASET = {
    "mnist": (MNIST, MNISTDataset),
    "cifar": (CIFAR10, CIFARDataset),
}

# Per-channel means passed to transforms.Normalize in preprocess().
MEAN = {
    "mnist": (0.1307,),
    "cifar": (0.4914, 0.4822, 0.4465),
}

# Per-channel standard deviations passed to transforms.Normalize in preprocess().
# NOTE(review): the commonly quoted MNIST std is 0.3081; 0.3015 may be a typo.
# Confirm before changing, since it alters the normalisation of every sample.
STD = {
    "mnist": (0.3015,),
    "cifar": (0.2023, 0.1994, 0.2010),
}

def preprocess(args: Namespace) -> None:
    """Download a dataset, split it across clients and persist one pickle per client.

    The first ``int(client_num_in_total * fraction)`` clients are built from the
    training split, the remaining clients from the test split. Output layout under
    ``CURRENT_DIR/<dataset>``:

    * ``pickles/<client_id>.pkl``  - the client's dataset object
    * ``pickles/seperation.pkl``   - dict with the ``train`` / ``test`` client ids
      and the ``total`` client count (file name kept as-is; readers depend on it)
    * ``all_stats.json``           - per-client sample counts and label histograms

    Any existing ``pickles`` directory is deleted first.

    Args:
        args: object exposing ``dataset`` ("mnist" or "cifar"), ``seed``,
            ``client_num_in_total``, ``fraction`` and ``classes`` (values <= 0
            fall back to 10 classes per client).
    """
    dataset_dir = os.path.join(CURRENT_DIR, args.dataset)
    pickles_dir = os.path.join(dataset_dir, "pickles")

    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    num_train_clients = int(args.client_num_in_total * args.fraction)
    num_test_clients = args.client_num_in_total - num_train_clients

    # Normalisation is attached to the client datasets and applied lazily on access.
    transform = transforms.Compose(
        [transforms.Normalize(MEAN[args.dataset], STD[args.dataset]),]
    )
    target_transform = None

    os.makedirs(dataset_dir, exist_ok=True)
    # Start from an empty pickle directory. shutil.rmtree replaces the previous
    # `rm -rf` shell call: no shell quoting issues and failures are not swallowed.
    if os.path.isdir(pickles_dir):
        shutil.rmtree(pickles_dir)
    os.mkdir(pickles_dir)

    ori_dataset, target_dataset = DATASET[args.dataset]
    trainset = ori_dataset(
        dataset_dir, train=True, download=True, transform=transforms.ToTensor()
    )
    testset = ori_dataset(dataset_dir, train=False, transform=transforms.ToTensor())

    num_classes = 10 if args.classes <= 0 else args.classes
    all_trainsets, trainset_stats = randomly_alloc_classes(
        ori_dataset=trainset,
        target_dataset=target_dataset,
        num_clients=num_train_clients,
        num_classes=num_classes,
        transform=transform,
        target_transform=target_transform,
    )
    all_testsets, testset_stats = randomly_alloc_classes(
        ori_dataset=testset,
        target_dataset=target_dataset,
        num_clients=num_test_clients,
        num_classes=num_classes,
        transform=transform,
        target_transform=target_transform,
    )

    # Client ids are positional: training clients first, then test clients.
    all_datasets = all_trainsets + all_testsets

    for client_id, dataset in enumerate(all_datasets):
        with open(os.path.join(pickles_dir, str(client_id) + ".pkl"), "wb") as f:
            pickle.dump(dataset, f)
    with open(os.path.join(pickles_dir, "seperation.pkl"), "wb") as f:
        pickle.dump(
            {
                "train": list(range(num_train_clients)),
                "test": list(range(num_train_clients, args.client_num_in_total)),
                "total": args.client_num_in_total,
            },
            f,
        )
    with open(os.path.join(dataset_dir, "all_stats.json"), "w") as f:
        json.dump({"train": trainset_stats, "test": testset_stats}, f)

def randomly_alloc_classes(
    ori_dataset: Dataset,
    target_dataset: Dataset,
    num_clients: int,
    num_classes: int,
    transform=None,
    target_transform=None,
) -> Tuple[List[Dataset], Dict[str, Dict[str, int]]]:
    """Slice ``ori_dataset`` across clients and wrap every slice in ``target_dataset``.

    The index assignment comes from fedlab's ``noniid_slicing``, called with
    ``num_clients * num_classes`` as its third argument (presumably the shard
    count, i.e. ``num_classes`` shards per client - confirm against fedlab docs).

    Args:
        ori_dataset: indexable dataset exposing ``targets``.
        target_dataset: dataset class called as
            ``target_dataset(samples, transform=..., target_transform=...)``.
        num_clients: number of clients to create.
        num_classes: multiplier used for the slicing call above.
        transform: forwarded to every created dataset.
        target_transform: forwarded to every created dataset.

    Returns:
        ``(datasets, stats)`` where ``datasets`` follows the iteration order of
        the slicing result and ``stats["client <i>"]`` holds ``"x"`` (sample
        count) and ``"y"`` (label ``Counter``).
    """
    client_to_indices = noniid_slicing(
        ori_dataset, num_clients, num_clients * num_classes
    )
    # Converted once; the same label array serves every client's histogram.
    all_labels = np.array(ori_dataset.targets)

    stats = {}
    datasets = []
    for client, sample_ids in client_to_indices.items():
        stats[f"client {client}"] = {
            "x": len(sample_ids),
            "y": Counter(all_labels[sample_ids].tolist()),
        }
        client_samples = [ori_dataset[sample_id] for sample_id in sample_ids]
        datasets.append(
            target_dataset(
                client_samples,
                transform=transform,
                target_transform=target_transform,
            )
        )
    return datasets, stats

class get_args_preprocess():
    """Default settings consumed by ``preprocess`` (notebook stand-in for CLI args).

    Attributes set on every instance:
        dataset: dataset key, "mnist" by default.
        client_num_in_total: total number of clients (200).
        fraction: share of clients built from the training split (0.9).
        classes: value forwarded as the per-client class count (2).
        seed: random seed used during preprocessing (0).
    """

    def __init__(self):
        defaults = (
            ("dataset", "mnist"),
            ("client_num_in_total", 200),
            ("fraction", 0.9),
            ("classes", 2),
            ("seed", 0),
        )
        for option, value in defaults:
            setattr(self, option, value)

# Run preprocessing once with the default configuration defined above; this
# rebuilds the per-client pickles under CURRENT_DIR/<dataset>/pickles.
# NOTE: the module-level name `args` is rebound later by the main cell.
args = get_args_preprocess()
preprocess(args)

### DATA/UTILS

In [105]:
# Dataset key -> in-memory dataset class stored in the per-client pickles.
DATASET_DICT = {
    "mnist": MNISTDataset,
    "cifar": CIFARDataset,
}
# Root folder that holds <dataset>/pickles; same hard-coded Drive path as in the
# preprocessing cell (script-style variant kept below for reference).
#CURRENT_DIR = Path(__file__).parent.abspath()
CURRENT_DIR = "/content/drive/Shareddrives/Duong-DatDeakin/Personalized_FedAvg"

def get_dataloader(dataset: str, client_id: int, batch_size=20, valset_ratio=0.1):
    """Load one client's pickled dataset and split it into train / validation loaders.

    Args:
        dataset: dataset key used as the folder name under ``CURRENT_DIR``.
        client_id: id of the client whose ``<client_id>.pkl`` is loaded.
        batch_size: batch size for both loaders.
        valset_ratio: fraction of the client's samples put into the validation split.

    Returns:
        ``(trainloader, valloader)``. The training loader drops the last
        incomplete batch; the validation loader keeps it.

    Raises:
        RuntimeError: if the pickles directory does not exist.
    """
    pickles_dir = "/".join((CURRENT_DIR, dataset, "pickles"))
    if not os.path.isdir(pickles_dir):
        raise RuntimeError("Please preprocess and create pickles first.")

    # NOTE: pickle.load executes arbitrary code; only load pickles written by preprocess().
    with open(f"{pickles_dir}/{client_id}.pkl", "rb") as f:
        client_dataset = pickle.load(f)

    total_samples = len(client_dataset)
    val_num_samples = int(valset_ratio * total_samples)
    split_sizes = [total_samples - val_num_samples, val_num_samples]

    # No generator is passed, so the split is drawn from torch's global RNG state.
    trainset, valset = random_split(client_dataset, split_sizes)

    return (
        DataLoader(trainset, batch_size, drop_last=True),
        DataLoader(valset, batch_size),
    )

def get_client_id_indices(dataset):
    """Read the client split written by ``preprocess``.

    Args:
        dataset: dataset key used as the folder name under ``CURRENT_DIR``.

    Returns:
        Tuple ``(train_ids, test_ids, total)`` taken from ``seperation.pkl``.
    """
    split_file = "/".join((CURRENT_DIR, dataset, "pickles", "seperation.pkl"))
    with open(split_file, "rb") as f:
        split = pickle.load(f)
    return tuple(split[key] for key in ("train", "test", "total"))

### MODEL

In [106]:
class elu(nn.Module):
    """ELU-style activation with a fixed negative-side scale of 0.2.

    ``forward(x)`` returns ``x`` where ``x >= 0`` and ``0.2 * (exp(x) - 1)``
    elsewhere.
    """

    def __init__(self) -> None:
        super(elu, self).__init__()

    def forward(self, x):
        # torch.where evaluates both branches. Without the clamp, exp(x) overflows
        # to inf for large positive x, and although that value is never selected
        # its gradient (0 * inf) turns the whole gradient into NaN. Clamping the
        # exponent input at 0 leaves every selected value unchanged.
        negative_side = 0.2 * (torch.exp(torch.clamp(x, max=0.0)) - 1)
        return torch.where(x >= 0, x, negative_side)


class linear(nn.Module):
    """Fully connected layer with explicitly initialised parameters ``w`` and ``b``.

    ``w`` has shape ``(out_c, in_c)`` and is drawn from a standard normal scaled
    by ``sqrt(2 / in_c)``; ``b`` has shape ``(out_c,)`` and is drawn from an
    unscaled standard normal.
    """

    def __init__(self, in_c, out_c) -> None:
        super().__init__()
        init_weight = torch.randn(out_c, in_c)
        init_scale = torch.sqrt(torch.tensor(2 / in_c))
        self.w = nn.Parameter(init_weight * init_scale)
        self.b = nn.Parameter(torch.randn(out_c))

    def forward(self, x):
        """Apply the affine map ``x @ w.T + b``."""
        return F.linear(x, weight=self.w, bias=self.b)


class MLP_MNIST(nn.Module):
    """Three-layer perceptron for flattened 28x28 inputs: 784 -> 80 -> 60 -> 10.

    The custom ``elu`` activation follows every layer, including the output layer.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order fixes the parameter order used when the model is serialised.
        self.fc1 = linear(28 * 28, 80)
        self.fc2 = linear(80, 60)
        self.fc3 = linear(60, 10)
        self.flatten = nn.Flatten()
        self.activation = elu()

    def forward(self, x):
        """Flatten the input and pass it through the three activated layers."""
        out = self.flatten(x)
        for layer in (self.fc1, self.fc2, self.fc3):
            out = self.activation(layer(out))
        return out


class MLP_CIFAR10(nn.Module):
    """Three-layer perceptron for flattened 32x32x3 inputs: 3072 -> 80 -> 60 -> 10.

    The custom ``elu`` activation follows every layer, including the output layer.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order fixes the parameter order used when the model is serialised.
        self.fc1 = linear(32 * 32 * 3, 80)
        self.fc2 = linear(80, 60)
        self.fc3 = linear(60, 10)
        self.flatten = nn.Flatten()
        self.activation = elu()

    def forward(self, x):
        """Flatten the input and pass it through the three activated layers."""
        out = self.flatten(x)
        for layer in (self.fc1, self.fc2, self.fc3):
            out = self.activation(layer(out))
        return out

# Dataset key -> model class built for that dataset.
MODEL_DICT = {
    "mnist": MLP_MNIST,
    "cifar": MLP_CIFAR10,
}


def get_model(dataset, device):
    """Instantiate the model registered for ``dataset`` and move it to ``device``.

    Raises:
        KeyError: if ``dataset`` is not a key of ``MODEL_DICT``.
    """
    model_cls = MODEL_DICT[dataset]
    model = model_cls()
    return model.to(device)

### UTILS

In [107]:
class get_args():
    """Default training settings (notebook stand-in for the commented-out CLI parser).

    Attributes set on every instance:
        alpha: inner-loop learning rate handed to the clients.
        beta: meta-loop learning rate handed to the clients.
        global_epochs: number of communication rounds.
        local_epochs: local update steps per client and round.
        pers_epochs: personalisation steps at evaluation time; -1 means
            "use local_epochs".
        hf: non-zero selects the Hessian-free variant, 0 the first-order one.
        batch_size: batch size of the client data loaders.
        valset_ratio: share of each client's data used for validation.
        dataset: "mnist" or "cifar".
        client_num_per_round: clients sampled in every round.
        seed: random seed.
        gpu: non-zero to use CUDA when available.
        eval_while_training: non-zero to evaluate before/after local training.
        log: passed to the rich console as ``record`` and to ``track`` as ``disable``.
        mix: non-zero enables the cross-client mixing step in the main loop.
    """

    def __init__(self):
        defaults = (
            ("alpha", 1e-2),
            ("beta", 1e-3),
            ("global_epochs", 200),
            ("local_epochs", 4),
            ("pers_epochs", 1),
            ("hf", 1),
            ("batch_size", 40),
            ("valset_ratio", 0.1),
            ("dataset", "mnist"),  # choices = ["mnist", "cifar"]
            ("client_num_per_round", 10),
            ("seed", 17),
            ("gpu", 1),
            ("eval_while_training", 1),
            ("log", 0),
            ("mix", 1),
        )
        for option, value in defaults:
            setattr(self, option, value)
    
# def get_args():
#     parser = ArgumentParser()
#     parser.add_argument("--alpha", type=float, default=1e-2)
#     parser.add_argument("--beta", type=float, default=1e-3)
#     parser.add_argument("--global_epochs", type=int, default=200)
#     parser.add_argument("--local_epochs", type=int, default=4)
#     parser.add_argument(
#         "--pers_epochs",
#         type=int,
#         default=1,
#         help="Indicate how many data batches would be used for personalization. Negatives means that equal to train phase.",
#     )
#     parser.add_argument(
#         "--hf",
#         type=int,
#         default=1,
#         help="0 for performing Per-FedAvg(FO), others for Per-FedAvg(HF)",
#     )
#     parser.add_argument("--batch_size", type=int, default=40)
#     parser.add_argument(
#         "--valset_ratio",
#         type=float,
#         default=0.1,
#         help="Proportion of val set in the entire client local dataset",
#     )
#     parser.add_argument(
#         "--dataset", type=str, choices=["mnist", "cifar"], default="mnist"
#     )
#     parser.add_argument("--client_num_per_round", type=int, default=10)
#     parser.add_argument("--seed", type=int, default=17)
#     parser.add_argument(
#         "--gpu",
#         type=int,
#         default=1,
#         help="Non-zero value for using gpu, 0 for using cpu",
#     )
#     parser.add_argument(
#         "--eval_while_training",
#         type=int,
#         default=1,
#         help="Non-zero value for performing local evaluation before and after local training",
#     )
#     parser.add_argument("--log", type=int, default=0)
#     return parser.parse_args()


@torch.no_grad()
def eval(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: Union[torch.nn.MSELoss, torch.nn.CrossEntropyLoss],
    device=torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ``model`` on every batch of ``dataloader`` without tracking gradients.

    NOTE: this function shadows Python's built-in ``eval``; the name is kept
    because callers in this notebook use it.

    Args:
        model: network to evaluate; it is switched to eval mode for the duration
            of the call and put back into the mode it was in beforehand.
        dataloader: iterable of ``(x, y)`` batches.
        criterion: loss applied to ``(logits, y)`` per batch.
        device: device the batches are moved to.

    Returns:
        ``(total_loss, accuracy)``: the SUM of the per-batch criterion values (not
        an average) and the fraction of correct argmax predictions. Both are
        zero-valued tensors when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss = 0
    num_samples = 0
    acc = 0
    try:
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logit = model(x)
            # total_loss += criterion(logit, y) / y.size(-1)
            total_loss += criterion(logit, y)
            # softmax is monotonic, so argmax over the raw logits picks the class directly.
            pred = logit.argmax(-1)
            acc += torch.eq(pred, y).int().sum()
            num_samples += y.size(-1)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    if num_samples == 0:
        # An empty loader (e.g. a client with no validation samples) previously
        # raised ZeroDivisionError on 0 / 0.
        return torch.zeros((), device=device), torch.zeros((), device=device)
    return total_loss, acc / num_samples

def get_data_batch(
    dataloader: torch.utils.data.DataLoader,
    iterator: Iterator,
    device=torch.device("cpu"),
):
    """Fetch the next ``(x, y)`` batch from ``iterator`` and move it to ``device``.

    When ``iterator`` is exhausted, a fresh iterator over ``dataloader`` is
    created and its first batch is returned.

    NOTE(review): the fresh iterator is local to this call. The caller's
    ``iterator`` stays exhausted, so every later call with that same iterator
    returns the first batch of ``dataloader`` again. Callers that need to walk
    through the data repeatedly should pass an iterator that restarts itself.

    Raises:
        StopIteration: if ``dataloader`` itself yields no batches.
    """
    exhausted = object()
    batch = next(iterator, exhausted)
    if batch is exhausted:
        batch = next(iter(dataloader))
    x, y = batch
    return x.to(device), y.to(device)

def fix_random_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed: value used for every random number generator.
    """
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not only the current one.
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm per run and can pick
    # different kernels between runs, which defeats the determinism requested
    # above; it therefore has to be off.
    torch.backends.cudnn.benchmark = False

### Per-FedAvg

In [108]:
from torch.distributions import Beta

class ClientInterpolation:
    """CutMix-style helper that pastes a random box from one batch into another.

    ``self.dist`` is a ``Beta(2, 2)`` distribution from which callers draw the
    mixing ratio ``lam``.
    """

    def __init__(self):
        self.dist = Beta(torch.FloatTensor([2]), torch.FloatTensor([2]))

    def rand_bbox(self, size, lam):
        """Draw a random box inside the two trailing axes of a 4-D batch.

        Args:
            size: batch shape; ``size[2]`` bounds the first box axis and
                ``size[3]`` the second.
            lam: mixing ratio as a Python number or a one-element tensor; the
                box covers roughly ``1 - lam`` of the area before clipping.

        Returns:
            ``(bbx1, bby1, bbx2, bby2)`` as plain ints, clipped to the bounds.
        """
        W = size[2]
        H = size[3]
        # float() accepts Python numbers and one-element tensors on any device;
        # the clip keeps sqrt well-defined should lam fall outside [0, 1].
        cut_rat = float(np.sqrt(np.clip(1.0 - float(lam), 0.0, 1.0)))
        # `np.int` was removed in NumPy 1.24; the builtin is the documented replacement.
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
        bby1 = int(np.clip(cy - cut_h // 2, 0, H))
        bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
        bby2 = int(np.clip(cy + cut_h // 2, 0, H))

        return bbx1, bby1, bbx2, bby2

    def mixup_data(self, x_support, x_query, lam):
        """Copy a random box of ``x_support`` into a clone of ``x_query``.

        Args:
            x_support: batch supplying the pasted region.
            x_query: batch receiving it; left unmodified (a clone is edited).
            lam: requested mixing ratio, forwarded to ``rand_bbox``.

        Returns:
            ``(mixed_x, lam)`` where ``lam`` is recomputed as one minus the
            fraction of the trailing two axes actually covered by the box.
        """
        mixed_x = x_query.clone()
        bbx1, bby1, bbx2, bby2 = self.rand_bbox(x_query.size(), lam)

        mixed_x[:, :, bbx1:bbx2, bby1:bby2] = x_support[:, :, bbx1:bbx2, bby1:bby2]

        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x_query.size()[-1] * x_query.size()[-2]))

        return mixed_x, lam

    # def client_crossmix(self, x1s, y1s, x1q, y1q, x2s, y2s, x2q, y2q):
        
    #     return None

In [109]:
class PerFedAvgClient(ClientInterpolation):
    """Federated client running Per-FedAvg local updates plus cross-client mixing.

    Each client owns a private copy of the model and train / validation loaders
    built from its pickled dataset. ``train`` performs the local meta-update and
    returns the serialised parameters; ``pers_N_eval`` fine-tunes for a few
    batches and reports validation metrics; ``client_crossmix`` builds mixed
    inputs from this client's batch and a partner client's batch.
    """

    def __init__(
        self,
        client_id: int,
        alpha: float,
        beta: float,
        global_model: torch.nn.Module,
        criterion: Union[torch.nn.CrossEntropyLoss, torch.nn.MSELoss],
        batch_size: int,
        dataset: str,
        local_epochs: int,
        valset_ratio: float,
        # qs_ratio: float
        logger: rich.console.Console,
        gpu: int,
    ):
        """Create the client, copy the global model and load the client's data.

        Args:
            client_id: id used to locate the client's pickle.
            alpha: inner-loop learning rate.
            beta: meta-loop learning rate.
            global_model: model whose deep copy becomes the local model.
            criterion: loss function.
            batch_size: batch size of both loaders.
            dataset: dataset key.
            local_epochs: number of local meta-update steps per ``train`` call.
            valset_ratio: share of the client's data used for validation.
            logger: rich console used for progress messages.
            gpu: non-zero to use CUDA when it is available.
        """
        super().__init__()  # sets self.dist, the Beta(2, 2) mixing-ratio distribution
        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.logger = logger
        self.dataset = dataset
        self.valset_ratio = valset_ratio
        self.local_epochs = local_epochs
        self.criterion = criterion
        self.id = client_id
        self.model = deepcopy(global_model)
        self.alpha = alpha                  # inner loop learning rate
        self.beta = beta                    # meta  loop learning rate
        self.batch_size = batch_size
        self.o_trainloader, self.valloader = get_dataloader(
            dataset, client_id, batch_size, valset_ratio
        )
        self.cr_trainloader = None
        # A self-restarting iterator: get_data_batch cannot refresh the caller's
        # iterator, so a plain iter(...) would replay the first batch forever once
        # the loader had been consumed.
        self.iter_trainloader = self._cycle(self.o_trainloader)
        self.num_classes = 10

    @staticmethod
    def _cycle(dataloader):
        """Yield batches from ``dataloader`` indefinitely, restarting after every pass."""
        while True:
            yielded_any = False
            for batch in dataloader:
                yielded_any = True
                yield batch
            if not yielded_any:
                # An empty loader would spin forever; stop so callers see StopIteration.
                return

    def train(
        self,
        global_model: torch.nn.Module,
        hessian_free=False,
        eval_while_training=False,
    ):
        """Run one round of local training starting from ``global_model``.

        Args:
            global_model: model whose weights are loaded before training.
            hessian_free: truthy for Per-FedAvg(HF), falsy for Per-FedAvg(FO).
            eval_while_training: truthy to log validation loss/accuracy before
                and after the local update.

        Returns:
            The local model's parameters serialised by ``SerializationTool``.
        """
        self.model.load_state_dict(global_model.state_dict())
        if eval_while_training:
            loss_before, acc_before = eval(
                self.model, self.valloader, self.criterion, self.device
            )
        self._train(hessian_free)

        if eval_while_training:
            loss_after, acc_after = eval(
                self.model, self.valloader, self.criterion, self.device
            )
            self.logger.log(
                "client [{}] [red]loss: {:.4f} -> {:.4f}   [blue]acc: {:.2f}% -> {:.2f}%".format(
                    self.id,
                    loss_before,
                    loss_after,
                    acc_before * 100.0,
                    acc_after * 100.0,
                )
            )
        return SerializationTool.serialize_model(self.model)

    def _train(self, hessian_free=False):
        """Apply ``local_epochs`` Per-FedAvg meta-update steps to ``self.model``."""
        # NOTE(review): this reads the notebook-level global `args`, and with mixing
        # enabled it selects the ORIGINAL loader; `cr_trainloader` is a raw tensor
        # (or None), not a loader. The selection only matters as the restart
        # fallback inside get_data_batch. Confirm the intended wiring once
        # client_crossmix is finished.
        if args.mix:
            self.trainloader = self.o_trainloader

        else:
            self.trainloader = self.cr_trainloader

        if hessian_free:  # Per-FedAvg(HF)
            for _ in range(self.local_epochs):
                temp_model = deepcopy(self.model)
                # 1st batch: inner-loop step on the temporary copy
                data_batch_1 = get_data_batch(
                    self.trainloader, self.iter_trainloader, self.device
                )
                grads = self.compute_grad(temp_model, data_batch_1)
                for param, grad in zip(temp_model.parameters(), grads):
                    param.data.sub_(self.alpha * grad)
                # 2nd batch: gradient at the adapted parameters
                data_batch_2 = get_data_batch(
                    self.trainloader, self.iter_trainloader, self.device
                )
                grads_1st = self.compute_grad(temp_model, data_batch_2)
                # 3rd batch: finite-difference Hessian-vector product
                data_batch_3 = get_data_batch(
                    self.trainloader, self.iter_trainloader, self.device
                )

                grads_2nd = self.compute_grad(
                    self.model, data_batch_3, v=grads_1st, second_order_grads=True
                )
                for param, grad1, grad2 in zip(
                    self.model.parameters(), grads_1st, grads_2nd
                ):
                    param.data.sub_(self.beta * grad1 - self.beta * self.alpha * grad2)

        else:  # Per-FedAvg(FO)
            for _ in range(self.local_epochs):
                # ========================== FedAvg ==========================
                # NOTE: You can uncomment those codes for running FedAvg.
                #       When you're trying to run FedAvg, comment other codes in this branch.

                # data_batch = utils.get_data_batch(
                #     self.trainloader, self.iter_trainloader, self.device
                # )
                # grads = self.compute_grad(self.model, data_batch)
                # for param, grad in zip(self.model.parameters(), grads):
                #     param.data.sub_(self.beta * grad)

                # ============================================================

                temp_model = deepcopy(self.model)
                data_batch_1 = get_data_batch(
                    self.trainloader, self.iter_trainloader, self.device
                )
                grads = self.compute_grad(temp_model, data_batch_1)

                for param, grad in zip(temp_model.parameters(), grads):
                    param.data.sub_(self.alpha * grad)

                data_batch_2 = get_data_batch(
                    self.trainloader, self.iter_trainloader, self.device
                )
                grads = self.compute_grad(temp_model, data_batch_2)

                for param, grad in zip(self.model.parameters(), grads):
                    param.data.sub_(self.beta * grad)

    def compute_grad(
        self,
        model: torch.nn.Module,
        data_batch: Tuple[torch.Tensor, torch.Tensor],
        v: Union[Tuple[torch.Tensor, ...], None] = None,
        second_order_grads=False,
    ):
        """Compute gradients of the loss on ``data_batch`` w.r.t. ``model``'s parameters.

        Args:
            model: network to differentiate.
            data_batch: ``(x, y)`` pair already on the right device.
            v: direction vector (one tensor per parameter); required when
                ``second_order_grads`` is true.
            second_order_grads: when true, return the central finite-difference
                estimate ``(grad(w + delta*v) - grad(w - delta*v)) / (2*delta)``
                with ``delta = 1e-3`` instead of the plain gradient. The model's
                parameters are restored afterwards.

        Returns:
            Sequence of gradient tensors aligned with ``model.parameters()``.
        """
        x, y = data_batch
        if second_order_grads:
            frz_model_params = deepcopy(model.state_dict())
            delta = 1e-3
            dummy_model_params_1 = OrderedDict()
            dummy_model_params_2 = OrderedDict()
            with torch.no_grad():
                for (layer_name, param), grad in zip(model.named_parameters(), v):
                    dummy_model_params_1.update({layer_name: param + delta * grad})
                    dummy_model_params_2.update({layer_name: param - delta * grad})

            model.load_state_dict(dummy_model_params_1, strict=False)
            logit_1 = model(x)
            # loss_1 = self.criterion(logit_1, y) / y.size(-1)
            loss_1 = self.criterion(logit_1, y)
            grads_1 = torch.autograd.grad(loss_1, model.parameters())

            model.load_state_dict(dummy_model_params_2, strict=False)
            logit_2 = model(x)
            loss_2 = self.criterion(logit_2, y)
            # loss_2 = self.criterion(logit_2, y) / y.size(-1)
            grads_2 = torch.autograd.grad(loss_2, model.parameters())

            # Put the unperturbed parameters back.
            model.load_state_dict(frz_model_params)

            grads = []
            with torch.no_grad():
                for g1, g2 in zip(grads_1, grads_2):
                    grads.append((g1 - g2) / (2 * delta))
            return grads

        else:
            logit = model(x)
            # loss = self.criterion(logit, y) / y.size(-1)
            loss = self.criterion(logit, y)
            grads = torch.autograd.grad(loss, model.parameters())
            return grads

    def pers_N_eval(self, global_model: torch.nn.Module, pers_epochs: int):
        """Personalise the global model for ``pers_epochs`` batches and evaluate it.

        Args:
            global_model: model whose weights are loaded before fine-tuning.
            pers_epochs: number of SGD steps (one batch each, learning rate ``alpha``).

        Returns:
            Dict with ``loss_before``, ``acc_before``, ``loss_after`` and
            ``acc_after`` measured on the validation loader.
        """
        self.model.load_state_dict(global_model.state_dict())

        loss_before, acc_before = eval(
            self.model, self.valloader, self.criterion, self.device
        )
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.alpha)
        for _ in range(pers_epochs):
            x, y = get_data_batch(
                self.o_trainloader, self.iter_trainloader, self.device
            )
            logit = self.model(x)
            # loss = self.criterion(logit, y) / y.size(-1)
            loss = self.criterion(logit, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_after, acc_after = eval(
            self.model, self.valloader, self.criterion, self.device
        )
        self.logger.log(
            "client [{}] [red]loss: {:.4f} -> {:.4f}   [blue]acc: {:.2f}% -> {:.2f}%".format(
                self.id, loss_before, loss_after, acc_before * 100.0, acc_after * 100.0,
            )
        )
        return {
            "loss_before": loss_before,
            "acc_before": acc_before,
            "loss_after": loss_after,
            "acc_after": acc_after,
        }

    # TODO: unfinished - the mixed inputs are stored but not yet fed into training.
    def client_crossmix(self, paired_id):
        """Mix one batch of this client with one batch of client ``paired_id``.

        A box of this client's (batch-shuffled) inputs is pasted into the
        partner's inputs via ``mixup_data``; the result is appended to
        ``self.cr_trainloader`` (a tensor, despite its name).

        Changes compared with the earlier draft of this method:
          * this client's own loaders are no longer reloaded here - doing so
            re-split train/validation at random on every call (moving validation
            samples into training) and left ``iter_trainloader`` pointing at the
            discarded loader;
          * the mixing ratio is moved to ``self.device`` instead of a hard-coded
            "cuda", so CPU-only runs work;
          * the partner batch is restarted from the partner's own loader;
          * a partner without any full batch is skipped instead of crashing.

        NOTE(review): only input tensors are kept (no labels) and the buffer
        grows with every call - revisit when finishing this feature.

        Args:
            paired_id: id of the partner client (may equal ``self.id``).
        """
        pair_trainloader, _ = get_dataloader(
            self.dataset, paired_id, self.batch_size, self.valset_ratio
        )
        if len(pair_trainloader) == 0:
            # Nothing to mix with; keep the buffer unchanged.
            return
        iter_pair_trainloader = iter(pair_trainloader)

        # lam_mix = 1 -> the pasted box is empty and the partner inputs stay untouched
        lam_mix = self.dist.sample().to(self.device)

        # Only a single batch is mixed per call.
        ori_b_x, _ = get_data_batch(
            self.o_trainloader, self.iter_trainloader, self.device
        )
        pair_b_x, _ = get_data_batch(
            pair_trainloader, iter_pair_trainloader, self.device
        )

        # Shuffle this client's samples along the batch axis before pasting.
        shuffled = ori_b_x.cpu().numpy()
        np.random.shuffle(shuffled)
        ori_b_x = torch.from_numpy(shuffled).to(self.device)

        x_mix_s, _ = self.mixup_data(ori_b_x, pair_b_x, lam_mix)

        if self.cr_trainloader is None:
            self.cr_trainloader = x_mix_s
        else:
            self.cr_trainloader = torch.cat([self.cr_trainloader, x_mix_s])

### MAIN

In [110]:
def pairing_client(client_set):
    """Shuffle ``client_set`` in place and pair its first half with its second half.

    Args:
        client_set: mutable sequence of client ids; it is reordered by this call.

    Returns:
        Array of shape ``(len(client_set) // 2, 2)`` whose rows are the pairs.
        An odd-length input makes the halves unequal, which ``np.stack`` rejects
        with ``ValueError``.
    """
    np.random.shuffle(client_set)
    half = len(client_set) // 2
    first_half, second_half = client_set[:half], client_set[half:]
    return np.stack([first_half, second_half], axis=1)


In [111]:
if __name__ == "__main__":
    # ---- setup ---------------------------------------------------------------
    args = get_args()
    fix_random_seed(args.seed)
    if not os.path.isdir("./log"):
        os.mkdir("./log")
    device = torch.device(
        "cuda" if args.gpu and torch.cuda.is_available() else "cpu"
    )
    global_model = get_model(args.dataset, device)
    logger = Console(record=args.log)
    #logger.log(f"Arguments:", dict(args._get_kwargs()))
    clients_4_training, clients_4_eval, client_num_in_total = get_client_id_indices(
        args.dataset
    )

    # ---- init clients (one per client id, training and evaluation alike) -----
    clients = []
    for client_id in range(client_num_in_total):
        clients.append(
            PerFedAvgClient(
                client_id=client_id,
                alpha=args.alpha,
                beta=args.beta,
                global_model=global_model,
                criterion=torch.nn.CrossEntropyLoss(),
                batch_size=args.batch_size,
                dataset=args.dataset,
                local_epochs=args.local_epochs,
                valset_ratio=args.valset_ratio,
                logger=logger,
                gpu=args.gpu,
            )
        )

    # ---- training --------------------------------------------------------------
    #logger.log("=" * 20, "TRAINING", "=" * 20, style="bold red")
    train_kwargs = dict(
        hessian_free=args.hf, eval_while_training=args.eval_while_training
    )
    for _ in track(
        range(args.global_epochs), "Training...", console=logger, disable=args.log
    ):
        # select clients and group them into pairs
        selected_clients = random.sample(clients_4_training, args.client_num_per_round)
        selected_client_pairs = pairing_client(selected_clients)

        model_params_cache = []
        # client local training
        for client_pair in selected_client_pairs:
            for client_idx, client_id in enumerate(client_pair):
                client = clients[client_id]
                if args.mix:
                    # random.randint is inclusive on both ends, so mix_c is one of
                    # 0..3 and `mix_c < 3` holds 3 times out of 4: the client is
                    # mixed with its partner with probability 75% and with itself
                    # otherwise. (An earlier comment here claimed 66.67%.)
                    mix_c = random.randint(0, 3)
                    if mix_c < 3:
                        partner_id = client_pair[int((client_idx + 1) % 2)]
                    else:
                        partner_id = client_id
                    client.client_crossmix(partner_id)
                # The local update itself is the same with or without mixing.
                serialized_model_params = client.train(
                    global_model=global_model, **train_kwargs
                )
                model_params_cache.append(serialized_model_params)

        # aggregate model parameters
        aggregated_model_params = Aggregators.fedavg_aggregate(model_params_cache)
        SerializationTool.deserialize_model(global_model, aggregated_model_params)
        #logger.log("=" * 60)

    # ---- evaluation ------------------------------------------------------------
    pers_epochs = args.local_epochs if args.pers_epochs == -1 else args.pers_epochs
    #logger.log("=" * 20, "EVALUATION", "=" * 20, style="bold blue")
    loss_before = []
    loss_after = []
    acc_before = []
    acc_after = []
    for client_id in track(
        clients_4_eval, "Evaluating...", console=logger, disable=args.log
    ):
        stats = clients[client_id].pers_N_eval(
            global_model=global_model, pers_epochs=pers_epochs,
        )
        loss_before.append(stats["loss_before"])
        loss_after.append(stats["loss_after"])
        acc_before.append(stats["acc_before"])
        acc_after.append(stats["acc_after"])

    # Summary logging is currently disabled; the per-client numbers stay in the
    # four lists above.
    #logger.log("=" * 20, "RESULTS", "=" * 20, style="bold green")
    #logger.log(f"loss_before_pers: {(sum(loss_before) / len(loss_before)):.4f}")
    #logger.log(f"acc_before_pers: {(sum(acc_before) * 100.0 / len(acc_before)):.2f}%")
    #logger.log(f"loss_after_pers: {(sum(loss_after) / len(loss_after)):.4f}")
    #logger.log(f"acc_after_pers: {(sum(acc_after) * 100.0 / len(acc_after)):.2f}%")

    if args.log:
        algo = "HF" if args.hf else "FO"
        # logger.save_html(
        #     f"./log/{args.dataset}_{args.client_num_per_round}_{args.global_epochs}_{pers_epochs}_{algo}.html"
        # )

Output()

Output()

In [112]:
import numpy as np
import torch
from torch.utils.data import Dataset
import pickle

class RainbowMNIST(Dataset):
    """Few-shot task sampler over the pickled RainbowMNIST groups.

    The pickle at ``{args.datadir}RainbowMNIST/rainbowmnist_all.pkl`` is indexed
    by integer group ids ``0 .. len-1``; every group holds ``'labels'`` and
    ``'images'``.  On load each group's images are reshaped to
    ``(10, 100, 28, 28, 3)``, cut to the first 20 entries along axis 1 and
    stored as a ``(10, 20, 3, 28, 28)`` tensor.  Axis 0 is used as the class
    label when sampling.

    NOTE: this notebook defines a class with the same name again in the next
    cell; that later definition is the one the ``dataloader`` object is built
    from.

    Args:
        args: settings object; reads ``num_classes``, ``update_batch_size``,
            ``update_batch_size_eval``, ``meta_batch_size``,
            ``metatrain_iterations``, ``ratio`` and ``datadir``.
        mode: ``'train'``, ``'val'`` or ``'test'``; selects which group ids
            tasks are drawn from.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values
            (previously this surfaced later as an AttributeError on
            ``sel_group_id``).
    """

    def __init__(self, args, mode):
        super(RainbowMNIST, self).__init__()
        # Fail fast, before the (potentially large) pickle is read.
        if mode not in ('train', 'val', 'test'):
            raise ValueError(
                f"mode must be one of 'train', 'val' or 'test', got {mode!r}"
            )

        self.args = args
        self.nb_classes = args.num_classes
        self.nb_samples_per_class = args.update_batch_size + args.update_batch_size_eval
        self.n_way = args.num_classes  # n-way
        self.k_shot = args.update_batch_size  # k-shot
        self.k_query = args.update_batch_size_eval  # for evaluation
        self.set_size = self.n_way * self.k_shot  # num of samples per set
        self.query_size = self.n_way * self.k_query  # number of samples per set for evaluation
        self.mode = mode
        self.data_file = '{}RainbowMNIST/rainbowmnist_all.pkl'.format(args.datadir)

        # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
        # The context manager closes the handle (the bare open() used before leaked it).
        with open(self.data_file, 'rb') as data_fh:
            self.data = pickle.load(data_fh)

        self.num_groupid = len(self.data.keys())

        for group_id in range(self.num_groupid):
            group = self.data[group_id]
            group['labels'] = group['labels'].reshape(10, 100)[:, :20]
            images = group['images'].reshape(10, 100, 28, 28, 3)[:, :20, ...]
            # channels-last -> channels-first: (10, 20, 28, 28, 3) -> (10, 20, 3, 28, 28)
            group['images'] = torch.tensor(np.transpose(images, (0, 1, 4, 2, 3)))

        if self.mode == 'train':
            self.sel_group_id = np.array([
                49, 8, 19, 47, 25, 27, 42, 50, 24, 40, 3, 45, 6, 41, 2, 17, 14,
                10, 5, 26, 12, 33, 9, 11, 32, 54, 28, 7, 39, 51, 46, 44, 30, 13,
                18, 0, 34, 43, 52, 29])
            num_of_tasks = self.sel_group_id.shape[0]
            if self.args.ratio < 1.0:
                # keep only the leading fraction of the training groups
                num_of_tasks = int(num_of_tasks * self.args.ratio)
                self.sel_group_id = self.sel_group_id[:num_of_tasks]
        elif self.mode == 'val':
            self.sel_group_id = np.array([15, 16, 38, 36, 37, 4])
        elif self.mode == 'test':
            self.sel_group_id = np.array([35, 48, 23, 20, 22, 55, 1, 21, 31, 53])

    def __len__(self):
        """Return ``metatrain_iterations * meta_batch_size``."""
        return self.args.metatrain_iterations * self.args.meta_batch_size

    def __getitem__(self, index):
        """Sample one meta-batch of tasks; ``index`` is not used.

        For each of ``meta_batch_size`` tasks a group id is drawn from
        ``sel_group_id``; for every class, ``k_shot`` support and ``k_query``
        query images are drawn without overlap from that group.

        Returns:
            ``(support_x, support_y, query_x, query_y)`` where the image
            tensors are float32 of shape ``(meta_batch_size, n_way * k, 3, 28, 28)``
            and the label tensors are int64 of shape ``(meta_batch_size, n_way * k)``.
        """
        self.classes_idx = np.arange(self.data[0]['images'].shape[0])
        self.samples_idx = np.arange(self.data[0]['images'].shape[1])

        meta_batch_size = self.args.meta_batch_size
        support_x = torch.zeros((meta_batch_size, self.set_size, 3, 28, 28), dtype=torch.float32)
        query_x = torch.zeros((meta_batch_size, self.query_size, 3, 28, 28), dtype=torch.float32)

        # Integer label buffers, so no float -> long cast is needed on return.
        support_y = np.zeros([meta_batch_size, self.set_size], dtype=np.int64)
        query_y = np.zeros([meta_batch_size, self.query_size], dtype=np.int64)

        for meta_batch_id in range(meta_batch_size):
            self.choose_group = np.random.choice(self.sel_group_id, size=1, replace=False).item()
            group_images = self.data[self.choose_group]['images']
            # Class count comes from the data's leading axis (10 after the reshape in __init__).
            for j in range(len(self.classes_idx)):
                np.random.shuffle(self.samples_idx)
                choose_samples = self.samples_idx[:self.nb_samples_per_class]
                spt = slice(j * self.k_shot, (j + 1) * self.k_shot)
                qry = slice(j * self.k_query, (j + 1) * self.k_query)
                support_x[meta_batch_id][spt] = group_images[j, choose_samples[:self.k_shot], ...]
                query_x[meta_batch_id][qry] = group_images[j, choose_samples[self.k_shot:], ...]
                support_y[meta_batch_id][spt] = j
                query_y[meta_batch_id][qry] = j

        return support_x, torch.from_numpy(support_y), query_x, torch.from_numpy(query_y)

In [113]:
class arg_mlti():
    """Plain settings holder whose attributes are read by ``RainbowMNIST``.

    Every value is stored as an ordinary instance attribute, so callers access
    them as ``args.num_classes``, ``args.ratio`` and so on.
    """

    def __init__(self):
        settings = {
            'datasource': 'rainbowmnist',
            'num_classes': 10,               # read by RainbowMNIST as n-way
            'num_test_task': 600,
            'test_epoch': -1,
            'metatrain_iterations': 15000,   # factor of RainbowMNIST.__len__
            'meta_batch_size': 25,           # tasks per sampled item
            'meta_lr': 0.001,
            'update_batch_size': 5,          # read by RainbowMNIST as k-shot
            'update_batch_size_eval': 15,    # read by RainbowMNIST as k-query
            'num_filters': 64,
            'weight_decay': 0.0,
            'logdir': 'xxx',
            'datadir': 'xxx',                # prefix of the RainbowMNIST pickle path
            'resume': 0,
            'train': 1,
            'mix': 0,
            'trial': 0,
            'ratio': 1.0,                    # < 1.0 trims the training group list
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

# Mount Google Drive and switch the working directory to the shared project
# folder, so the relative paths used later in this cell resolve against it.
# The same mount + %cd already appears in the MOUNT cell near the top of the
# notebook; the saved output below reports the drive as already mounted.
from google.colab import drive
drive.mount("/content/drive")
%cd '/content/drive/Shareddrives/Duong-DatDeakin/Personalized_FedAvg'

# Settings object passed to RainbowMNIST further down in this cell.
args2 = arg_mlti()

import numpy as np
import torch
from torch.utils.data import Dataset
import pickle

class RainbowMNIST(Dataset):
    """Few-shot task sampler over the pickled RainbowMNIST groups.

    The pickle at ``{args.datadir}RainbowMNIST/rainbowmnist_all.pkl`` is indexed
    by integer group ids ``0 .. len-1``; every group holds ``'labels'`` and
    ``'images'``.  On load each group's images are reshaped to
    ``(10, 100, 28, 28, 3)``, cut to the first 20 entries along axis 1 and
    stored as a ``(10, 20, 3, 28, 28)`` tensor.  Axis 0 is used as the class
    label when sampling.

    NOTE: a class with the same name is also defined in the preceding cell;
    this later definition replaces it once the cell has run.

    Args:
        args: settings object; reads ``num_classes``, ``update_batch_size``,
            ``update_batch_size_eval``, ``meta_batch_size``,
            ``metatrain_iterations``, ``ratio`` and ``datadir``.
        mode: ``'train'``, ``'val'`` or ``'test'``; selects which group ids
            tasks are drawn from.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values
            (previously this surfaced later as an AttributeError on
            ``sel_group_id``).
    """

    def __init__(self, args, mode):
        super(RainbowMNIST, self).__init__()
        # Fail fast, before the (potentially large) pickle is read.
        if mode not in ('train', 'val', 'test'):
            raise ValueError(
                f"mode must be one of 'train', 'val' or 'test', got {mode!r}"
            )

        self.args = args
        self.nb_classes = args.num_classes
        self.nb_samples_per_class = args.update_batch_size + args.update_batch_size_eval
        self.n_way = args.num_classes  # n-way
        self.k_shot = args.update_batch_size  # k-shot
        self.k_query = args.update_batch_size_eval  # for evaluation
        self.set_size = self.n_way * self.k_shot  # num of samples per set
        self.query_size = self.n_way * self.k_query  # number of samples per set for evaluation
        self.mode = mode
        self.data_file = '{}RainbowMNIST/rainbowmnist_all.pkl'.format(args.datadir)

        # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
        # The context manager closes the handle (the bare open() used before leaked it).
        with open(self.data_file, 'rb') as data_fh:
            self.data = pickle.load(data_fh)

        self.num_groupid = len(self.data.keys())

        for group_id in range(self.num_groupid):
            group = self.data[group_id]
            group['labels'] = group['labels'].reshape(10, 100)[:, :20]
            images = group['images'].reshape(10, 100, 28, 28, 3)[:, :20, ...]
            # channels-last -> channels-first: (10, 20, 28, 28, 3) -> (10, 20, 3, 28, 28)
            group['images'] = torch.tensor(np.transpose(images, (0, 1, 4, 2, 3)))

        if self.mode == 'train':
            self.sel_group_id = np.array([
                49, 8, 19, 47, 25, 27, 42, 50, 24, 40, 3, 45, 6, 41, 2, 17, 14,
                10, 5, 26, 12, 33, 9, 11, 32, 54, 28, 7, 39, 51, 46, 44, 30, 13,
                18, 0, 34, 43, 52, 29])
            num_of_tasks = self.sel_group_id.shape[0]
            if self.args.ratio < 1.0:
                # keep only the leading fraction of the training groups
                num_of_tasks = int(num_of_tasks * self.args.ratio)
                self.sel_group_id = self.sel_group_id[:num_of_tasks]
        elif self.mode == 'val':
            self.sel_group_id = np.array([15, 16, 38, 36, 37, 4])
        elif self.mode == 'test':
            self.sel_group_id = np.array([35, 48, 23, 20, 22, 55, 1, 21, 31, 53])

    def __len__(self):
        """Return ``metatrain_iterations * meta_batch_size``."""
        return self.args.metatrain_iterations * self.args.meta_batch_size

    def __getitem__(self, index):
        """Sample one meta-batch of tasks; ``index`` is not used.

        For each of ``meta_batch_size`` tasks a group id is drawn from
        ``sel_group_id``; for every class, ``k_shot`` support and ``k_query``
        query images are drawn without overlap from that group.

        Returns:
            ``(support_x, support_y, query_x, query_y)`` where the image
            tensors are float32 of shape ``(meta_batch_size, n_way * k, 3, 28, 28)``
            and the label tensors are int64 of shape ``(meta_batch_size, n_way * k)``.
        """
        self.classes_idx = np.arange(self.data[0]['images'].shape[0])
        self.samples_idx = np.arange(self.data[0]['images'].shape[1])

        meta_batch_size = self.args.meta_batch_size
        support_x = torch.zeros((meta_batch_size, self.set_size, 3, 28, 28), dtype=torch.float32)
        query_x = torch.zeros((meta_batch_size, self.query_size, 3, 28, 28), dtype=torch.float32)

        # Integer label buffers, so no float -> long cast is needed on return.
        support_y = np.zeros([meta_batch_size, self.set_size], dtype=np.int64)
        query_y = np.zeros([meta_batch_size, self.query_size], dtype=np.int64)

        for meta_batch_id in range(meta_batch_size):
            self.choose_group = np.random.choice(self.sel_group_id, size=1, replace=False).item()
            group_images = self.data[self.choose_group]['images']
            # Class count comes from the data's leading axis (10 after the reshape in __init__).
            for j in range(len(self.classes_idx)):
                np.random.shuffle(self.samples_idx)
                choose_samples = self.samples_idx[:self.nb_samples_per_class]
                spt = slice(j * self.k_shot, (j + 1) * self.k_shot)
                qry = slice(j * self.k_query, (j + 1) * self.k_query)
                support_x[meta_batch_id][spt] = group_images[j, choose_samples[:self.k_shot], ...]
                query_x[meta_batch_id][qry] = group_images[j, choose_samples[self.k_shot:], ...]
                support_y[meta_batch_id][spt] = j
                query_y[meta_batch_id][qry] = j

        return support_x, torch.from_numpy(support_y), query_x, torch.from_numpy(query_y)

# NOTE: despite the name, this is a RainbowMNIST Dataset instance, not a torch
# DataLoader. Each item it returns already carries a leading meta_batch_size
# dimension (see RainbowMNIST.__getitem__).
dataloader = RainbowMNIST(args2, 'train')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/Shareddrives/Duong-DatDeakin/Personalized_FedAvg


In [114]:
def rand_bbox(size, lam):
    """Sample a random rectangle whose side ratio is ``sqrt(1 - lam)``.

    Args:
        size: shape of a 4-D batch; only ``size[2]`` and ``size[3]`` are read.
            ``size[2]`` bounds the x-coordinates and ``size[3]`` the
            y-coordinates, matching how ``mixup_data`` slices dims 2 and 3.
        lam: mixing coefficient, either a Python/NumPy number or a
            single-element tensor (on any device).

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` integer bounds, clipped to
        ``[0, size[2]]`` / ``[0, size[3]]``.  The box centre is drawn from
        NumPy's global RNG, so results depend on ``np.random`` state.
    """
    W = size[2]
    H = size[3]
    # float() handles plain numbers and one-element tensors alike, replacing
    # the tensor-only ``lam.cpu()`` call.
    cut_rat = np.sqrt(1.0 - float(lam))
    # Built-in int() replaces ``np.int``, which was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # Centre the box on (cx, cy) and clip it to the image bounds.
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def mixup_data(xs, xq, lam):
    """Paste a random rectangular patch of ``xs`` into a copy of ``xq``.

    Despite the name this is a CutMix-style patch swap rather than a convex
    blend: the box comes from ``rand_bbox(xq.size(), lam)`` and the same
    region (dims 2 and 3) is copied from ``xs`` for every sample and channel.

    Args:
        xs: source batch supplying the patch.
        xq: batch that is cloned and receives the patch; it is not modified.
        lam: mixing coefficient forwarded to ``rand_bbox``.

    Returns:
        ``(mixed_x, lam)`` where ``lam`` is recomputed as one minus the
        fraction of the last two dims' area covered by the pasted box.
    """
    x_lo, y_lo, x_hi, y_hi = rand_bbox(xq.size(), lam)

    blended = xq.clone()
    blended[:, :, x_lo:x_hi, y_lo:y_hi] = xs[:, :, x_lo:x_hi, y_lo:y_hi]

    # Clipping may have shrunk the box, so derive lam from its actual area.
    patch_area = (x_hi - x_lo) * (y_hi - y_lo)
    full_area = xq.size()[-1] * xq.size()[-2]
    lam = 1 - (patch_area / full_area)

    return blended, lam

In [115]:
# Debug switch; within this cell it is referenced only by the commented-out
# inspection code further down.
test = False

class get_args():
    """Hard-coded stand-in for the command-line options of the training run.

    All options are stored as plain instance attributes (``args.hf``,
    ``args.log`` and so on).
    """

    def __init__(self):
        options = {
            "alpha": 1e-2,
            "beta": 1e-3,
            "global_epochs": 200,
            "local_epochs": 4,
            "pers_epochs": 1,            # -1 makes evaluation fall back to local_epochs
            "hf": 1,                     # forwarded to client training as hessian_free
            "batch_size": 40,
            "valset_ratio": 0.1,
            "dataset": "mnist",          #choices = ["mnist", "cifar"]
            "client_num_per_round": 10,
            "seed": 17,
            "gpu": 1,
            "eval_while_training": 1,
            "log": 0,
            "mix": 1,
        }
        for option_name, option_value in options.items():
            setattr(self, option_name, option_value)
    
# Build the run configuration used by the rest of this cell.
args = get_args()

# Beta(2, 2) distribution from which the mixing coefficient is drawn below.
# NOTE(review): `Beta` is not among the imports visible at the top of the
# notebook -- confirm it is imported (e.g. from torch.distributions) before
# this cell runs on a fresh kernel.
dist = Beta(torch.FloatTensor([2]), torch.FloatTensor([2]))

# Train/validation loaders requested with second argument 0 -- presumably a
# client id; TODO confirm against get_dataloader's signature.
o_trainloader, valloader = get_dataloader(
            args.dataset, int(0), args.batch_size, args.valset_ratio
        )
iter_trainloader = iter(o_trainloader)

# Second pair of loaders, requested with second argument 1.
pair_trainloader, pair_valloader = get_dataloader(
    args.dataset, int(1), args.batch_size, args.valset_ratio
)
iter_ptrainloader = iter(pair_trainloader)

# Single sample of the mixing coefficient, moved to the GPU (hard-coded "cuda",
# so this line needs a CUDA device). With lam_mix = 1, rand_bbox produces a
# zero-area box and mixup_data returns its second input unchanged.
lam_mix = dist.sample().to("cuda") # lam_mix = 1 -> mixed_representation = x2s

# Accumulators that are only filled by the commented-out experiment below.
cr_trainloader = None
cr_list = []

# Length of the first train loader (7 in the saved output).
print(len(o_trainloader))

# processing on different batches and then concatenate them
# for ep in range(args.global_epochs):
#     if ep == 1:
#         break

#     ori_b_x, ori_b_y = get_data_batch(
#             o_trainloader, iter_trainloader, device
#         )
#     old_b_x = ori_b_x.cpu().numpy()
#     ori_b_x = ori_b_x.cpu().numpy()
#     pair_b_x, pair_b_y = get_data_batch(
#             o_trainloader, iter_ptrainloader, device
#         )
#     # pair_b_x = pair_b_x.cpu().numpy()
#     np.random.shuffle(ori_b_x)    

#     # print(np.linalg.norm(ori_b_x - old_b_x))
#     ori_b_x = torch.from_numpy(ori_b_x).to(device)
#     # print(ori_b_x.size())
#     # print(pair_b_x.size())
#     # print(ori_b_x)
#     # print(pair_b_x)
#     x_mix_s, _ = mixup_data(ori_b_x, pair_b_x, lam_mix)
#     # if test == False: 
#     #     x_mix_np = x_mix_s.cpu().numpy()
#     #     x_ori_np = ori_b_x.cpu().numpy()
#     #     x_pair_np = pair_b_x.cpu().numpy()
#     #     for idx, x_mix_dat in enumerate(x_mix_s):
#     #         print(f"data point: {x_mix_dat.size()} - {x_mix_dat}")

#     for idx, x_mix_dat in enumerate(x_mix_s):
#         cr_list.append(x_mix_dat)
#     cr_trainloader = torch.stack(cr_list) \
#                     if cr_trainloader == None \
#                     else torch.cat([cr_trainloader,torch.stack(cr_list)])
    
#     print(len(cr_trainloader))
#     for idx, cr_traindat in enumerate(cr_trainloader):
#         print(cr_traindat)

7


In [116]:
# Peek at a single sampled item and print the first task's support images.
# `dataloader` is the RainbowMNIST dataset itself; its __getitem__ ignores the
# index and never raises, and the class defines no __iter__ of its own, so the
# explicit break is what ends this loop.
for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(dataloader):
    # print(x_spt.size())
    # print(x_spt.squeeze(0).size())
    # print(x_spt.squeeze(0)[0].size())
    # squeeze(0) only drops dim 0 when it has size 1; [0] then selects the first entry.
    print(x_spt.squeeze(0)[0])
    # Stop after the first item.
    if step == 0:
        break

tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.4980, 0.4980, 0.4980,  ..., 0.4980, 0.4980, 0.4980],
          [0.4980, 0.4980, 0.4980,  ..., 0.4980, 0.4980, 0.4980],
          [0.4980, 0.4980, 0.4980,  ..., 0.4980, 0.4980, 0.4980],
          ...,
          [0.4980, 0.4980, 0.4980,  ..., 0.4980, 0.4980, 0.4980],
          [0.4980, 0.4980, 0.4980,  ..., 0.4980, 0.4980, 0.4980],
          [0.4980, 0.4980, 0.4980,  ..., 0.4980, 0.4980, 0.4980]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0