FedPCG:一种利用聚类采样和全局原型的个性化联邦学习方法,基于隐语框架和FedNH baseline实现, reference:Dai, Y., Chen, Z., Li, J., Heinecke, S., Sun, L., & Xu, R. (2023, June). Tackling data heterogeneity in federated learning with class prototypes. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 37, No. 6, pp. 7314-7322). https://github.com/Yutong-Dai/FedNH

In [None]:
import secretflow as sf
from secretflow import PYUObject, proxy
import os
import yaml
import sys
import argparse
import time
import pickle
import warnings
import random
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.distributions.multivariate_normal import MultivariateNormal
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from itertools import compress, product
from PIL import Image
from collections import OrderedDict, Counter
from tqdm import tqdm
from copy import deepcopy
from scipy.cluster.hierarchy import fcluster, linkage

# Optional dependency probe: Weights & Biases logging can only be turned on
# (together with args.use_wandb) when the `wandb` package is importable.
try:
    import wandb
except ModuleNotFoundError:
    wandb_installed = False
else:
    wandb_installed = True
print(torch.__version__)

超参数设置：各超参数的默认值在下方 args_parser 中定义，可通过同级目录下 config.ini 的 [train] 节覆盖对应的超参数。

In [None]:
import argparse
import configparser


def args_parser(config_path="config.ini", section="train"):
    """Build the experiment argument namespace from an INI config file.

    Every option declared below has a default; each ``key = value`` pair in
    ``section`` of the INI file at ``config_path`` is turned into
    ``--key value`` and overrides that default. ``sys.argv`` is not consulted.

    INI keys are matched to options case-insensitively. ``ConfigParser``
    lower-cases keys, so a key such as ``FedNH_smoothing`` used to become
    ``--fednh_smoothing`` and argparse rejected it.

    Args:
        config_path: path of the INI file, resolved against the current
            working directory (default ``"config.ini"``).
        section: INI section to read (default ``"train"``).

    Returns:
        argparse.Namespace with one attribute per declared option.

    Raises:
        FileNotFoundError: if ``config_path`` could not be read.
        KeyError: if the file has no ``section`` section.
        SystemExit: raised by argparse for an unknown key or a value that
            cannot be converted to the option's type.
    """

    def str2bool(value):
        # "true"/"1"/"yes" (any case) -> True; anything else -> False.
        return str(value).lower() in ["true", "1", "yes"]

    parser = argparse.ArgumentParser(description="Test Algorithms.")
    # Lower-cased option name -> option string as declared, used to map INI
    # keys (lower-cased by ConfigParser) back onto mixed-case options.
    option_by_lower = {}

    def add(flag, **kwargs):
        option_by_lower[flag[2:].lower()] = flag
        parser.add_argument(flag, **kwargs)

    # general settings
    add("--purpose", default="experiments", type=str, help="purpose of this run")
    add("--device", default="cuda:1", type=str, help="cuda device")
    add("--global_seed", default=2022, type=int, help="Global random seed.")
    add("--use_wandb", default=True, type=str2bool, help="Use wandb pkg")
    add(
        "--keep_clients_model",
        default=False,
        type=str2bool,
        help="Keep FedAVG local model",
    )
    # model architecture
    add("--no_norm", default=False, type=str2bool, help="Use group/batch norm or not")
    # optimizer
    add("--optimizer", default="SGD", type=str, help="Optimizer")
    add("--num_epochs", default=5, type=int, help="num local epochs")
    add(
        "--client_lr", default=0.1, type=float, help="client side initial learning rate"
    )
    add(
        "--client_lr_scheduler",
        default="stepwise",
        type=str,
        help="client side learning rate update strategy",
    )
    add("--sgd_momentum", default=0.0, type=float, help="sgd momentum")
    add("--sgd_weight_decay", default=1e-5, type=float, help="sgd weight decay")
    add("--use_sam", default=False, type=str2bool, help="Use SAM optimizer")
    # server config
    add("--yamlfile", default=None, type=str, help="Configuration file.")
    add("--strategy", default=None, type=str, help="strategy FL")
    add("--num_clients", default=100, type=int, help="number of clients")
    add("--num_rounds", default=200, type=int, help="number of communication rounds")
    add("--participate_ratio", default=0.1, type=float, help="participate ratio")
    add(
        "--partition", default=None, type=str, help="method for partition the dataset"
    )
    add("--beta", default=None, type=str, help="Dirichlet Distribution parameter")
    add(
        "--num_classes_per_client",
        default=None,
        type=int,
        help="pathological non-iid parameter",
    )
    add(
        "--num_shards_per_client",
        default=None,
        type=int,
        help="pathological non-iid parameter fedavg simulation",
    )

    # strategy parameters
    add("--FedNH_smoothing", default=0.9, type=float, help="moving average parameters")
    add(
        "--FedNH_server_adv_prototype_agg",
        default=False,
        type=str2bool,
        help="FedNH server adv agg",
    )
    add(
        "--FedNH_client_adv_prototype_agg",
        default=False,
        type=str2bool,
        help="FedNH client adv agg",
    )
    add(
        "--FedROD_hyper_clf",
        default=True,
        type=str2bool,
        help="FedRod phead uses hypernetwork",
    )
    add(
        "--FedROD_phead_separate",
        default=False,
        type=str2bool,
        help="FedROD phead separate train",
    )
    add(
        "--FedProto_lambda",
        default=0.1,
        type=float,
        help="FedProto local penalty lambda",
    )
    add(
        "--FedRep_head_epochs",
        default=10,
        type=int,
        help="FedRep local epochs to update head",
    )
    add(
        "--FedBABU_finetune_epoch",
        default=5,
        type=int,
        help="FedBABU local epochs to finetune",
    )
    add("--Ditto_lambda", default=0.75, type=float, help="penalty parameter for Ditto")
    add(
        "--CReFF_num_of_fl_feature",
        default=100,
        type=int,
        help="num of federated feature per class",
    )
    add(
        "--CReFF_match_epoch",
        default=100,
        type=int,
        help="epoch used to minmize gradient matching loss",
    )
    add(
        "--CReFF_crt_epoch",
        default=300,
        type=int,
        help="epoch used to retrain classifier",
    )
    add("--CReFF_lr_net", default=0.01, type=float, help="lr for head")
    add("--CReFF_lr_feature", default=0.1, type=float, help="lr for feature")

    config = configparser.ConfigParser()
    # read() returns the list of files it managed to parse; an empty list
    # means the file is missing or unreadable.
    if not config.read(config_path):
        raise FileNotFoundError(f"Could not read config file: {config_path}")

    arg_list = []
    for key, value in config[section].items():
        # Unknown keys keep their "--key" form so argparse still reports them.
        arg_list.extend([option_by_lower.get(key.lower(), "--" + key), value])

    return parser.parse_args(arg_list)


# Parse hyper-parameters: option defaults come from args_parser(), overridden
# by the key/value pairs in the [train] section of config.ini.
args = args_parser()
print(args)

设置随机种子，并用 args 中的部分参数覆盖 YAML 配置文件（args.yamlfile）中 server_config 和 client_config 的对应参数。

In [None]:
def setup_seed(seed):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for reproducible results.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning picks kernels by timing, which can vary between runs;
    # disable it explicitly so determinism does not depend on earlier code.
    torch.backends.cudnn.benchmark = False


# wandb logging needs both the package and the user's opt-in.
use_wandb = wandb_installed and args.use_wandb
setup_seed(args.global_seed)

# Fail fast with an actionable message; open(None) would raise an opaque TypeError.
if args.yamlfile is None:
    raise ValueError(
        "No YAML config file given: set `yamlfile` in the [train] section of config.ini."
    )

# NOTE(review): yaml.Loader can construct arbitrary Python objects from tagged
# YAML. That is acceptable for a trusted local config file; switch to
# yaml.safe_load if the file may come from an untrusted source.
with open(args.yamlfile, "r", encoding="utf-8") as stream:
    config = yaml.load(stream, Loader=yaml.Loader)

# parse the default setting
server_config = config["server_config"]
client_config = config["client_config"]

# overwrite the YAML defaults with the values from args
server_config["strategy"] = args.strategy
server_config["num_clients"] = args.num_clients
server_config["num_rounds"] = args.num_rounds
server_config["participate_ratio"] = args.participate_ratio
server_config["partition"] = args.partition
server_config["beta"] = args.beta
server_config["num_classes_per_client"] = args.num_classes_per_client
server_config["num_shards_per_client"] = args.num_shards_per_client
client_config["num_rounds"] = args.num_rounds
client_config["global_seed"] = args.global_seed
client_config["optimizer"] = args.optimizer
client_config["client_lr"] = args.client_lr
client_config["client_lr_scheduler"] = args.client_lr_scheduler
client_config["sgd_momentum"] = args.sgd_momentum
client_config["sgd_weight_decay"] = args.sgd_weight_decay
client_config["use_sam"] = args.use_sam
client_config["no_norm"] = args.no_norm

# Short "<name>:<value>" description of the partition's controlling parameter.
if server_config["partition"] == "noniid-label-distribution":
    partition_arg = f"beta:{args.beta}"
elif server_config["partition"] == "noniid-label-quantity":
    partition_arg = f"num_classes_per_client:{args.num_classes_per_client}"
elif server_config["partition"] == "shards":
    partition_arg = f"num_shards_per_client:{args.num_shards_per_client}"
else:
    raise ValueError("not implemented partition")
print(server_config)
print(client_config)
print(client_config)

客户端基类，主要用于联邦学习环境下各客户端的method

In [None]:
class Client:
    """Base class for a federated-learning client.

    Wraps the local train/test sets in DataLoaders, records their label
    distributions and holds per-round statistics. Subclasses implement
    ``initialize_model``, ``training``, ``testing`` and ``upload``.

    Args:
        criterion: loss function, stored as ``self.criterion`` for subclasses.
        trainset: local training set or None. Must expose a ``targets``
            tensor, because ``.numpy()`` is called on it.
        testset: local test set or None, with the same requirement.
        client_config: dict; ``client_config["batch_size"]`` is read here.
        cid: client id (the server-side evaluation client uses -1).
        **kwargs: extra keyword options (see the note in ``__init__``).
    """

    def __init__(self, criterion, trainset, testset, client_config, cid, **kwargs):
        # Must run first: it copies every local name (the arguments) onto self.
        autoassign(locals())
        self.num_train_samples = len(trainset) if trainset is not None else 0
        self.num_test_samples = len(testset) if testset is not None else 0

        # NOTE(review): autoassign only flattens **kwargs when self already has
        # a `kwargs` attribute, so e.g. device=... passed by the server is not
        # stored. self.device is therefore only guaranteed to be set here on
        # CPU-only hosts; confirm that subclasses set it otherwise.
        if not torch.cuda.is_available():
            self.device = "cpu"
            print("cuda is not available. use cpu instead.")
        # wrap the trainset and testset with dataloader
        self._prepare_data()
        # local stats
        self.num_rounds_particiapted = 0
        self.train_loss_dict = OrderedDict()
        self.train_acc_dict = OrderedDict()
        self.test_loss_dict = OrderedDict()
        self.test_acc_dict = OrderedDict()
        self.new_state_dict = None

    def _prepare_data(self):
        """Build the loaders and label distributions for the local datasets.

        Sets ``trainloader``/``testloader``. Each is None when the
        corresponding set is empty. The training loader shuffles; the test
        loader does not and uses twice the batch size.

        Also sets ``count_by_class``/``label_dist`` and
        ``count_by_class_test``/``label_dist_test``. ``label_dist`` maps
        ``label -> fraction of samples``. Both distributions are None when the
        corresponding set is empty, so callers can test for that case.
        """
        self.label_dist = None
        self.label_dist_test = None
        batch_size = self.client_config["batch_size"]
        train_batchsize = min(batch_size, self.num_train_samples)
        test_batchsize = min(batch_size * 2, self.num_test_samples)

        if self.num_train_samples > 0:
            self.trainloader = DataLoader(
                self.trainset, batch_size=train_batchsize, shuffle=True
            )
            # summarize training set label distribution
            self.count_by_class = Counter(self.trainset.targets.numpy())
            self.label_dist = {
                i: self.count_by_class[i] / self.num_train_samples
                for i in sorted(self.count_by_class.keys())
            }
        else:
            self.trainloader = None
            self.count_by_class = Counter()

        if self.num_test_samples > 0:
            self.testloader = DataLoader(
                self.testset, batch_size=test_batchsize, shuffle=False
            )
            self.count_by_class_test = Counter(self.testset.targets.numpy())
            self.label_dist_test = {
                i: self.count_by_class_test[i] / self.num_test_samples
                for i in sorted(self.count_by_class_test.keys())
            }
        else:
            self.testloader = None
            self.count_by_class_test = Counter()

    def set_params(self, model_state_dict, exclude_keys):
        """Copy entries of ``model_state_dict`` into the local model, skipping ``exclude_keys``."""
        self.model.set_params(model_state_dict, exclude_keys)

    def get_params(self):
        """Return the local model's state dict."""
        return self.model.get_params()

    def get_params_values(self):
        """Delegate to ``self.model.get_params_values()``.

        NOTE(review): the Model base class does not define get_params_values;
        the concrete model must provide it.
        """
        return self.model.get_params_values()

    def get_model_parameters(self):
        """Return the local model's parameters as a list."""
        return list(self.model.get_parameters())

    def get_model(self):
        """Switch the local model to eval mode and return it."""
        self.model.eval()
        return self.model

    def get_grads(self, dataloader):
        """Delegate to ``self.model.get_grads(dataloader)``.

        NOTE(review): the Model base class names this method get_gradients,
        not get_grads; the concrete model must provide get_grads.
        """
        return self.model.get_grads(dataloader)

    def initialize_model(self):
        raise NotImplementedError(
            "Please write a method for the client to initialize the model(s)."
        )

    def training(self, round, num_epochs):
        raise NotImplementedError("Please write a training method for the client.")

    def testing(self, round, testloader=None):
        """
        Provide testloader if one wants to use an external testing dataset.
        """
        raise NotImplementedError("Please write a testing method for the client.")

    def upload(self):
        """
        Decide what information to share with the server
        """
        raise NotImplementedError

服务器基类，创建一个fake客户端用于全局模型性能测试。

In [None]:
# Secretflow PYU device passed as `device` to the server-side "fake" client
# that Server.__init__ builds to evaluate the global model.
fake_client_pyu = sf.PYU("fake_client")


class Server:
    """Base class for the federated-learning server.

    Tracks round counters and train/test statistics, and builds a server-side
    "fake" client (cid -1) on the global train/test sets. That client is used
    to evaluate the global model; the trainset only provides its label
    distribution. Subclasses implement ``testing``, ``collect_stats``,
    ``aggregate``, ``run`` and ``summary_result``.

    Args:
        server_config: dict of server settings (strategy, partition, ...).
        clients_dict: mapping of client id -> Client; its length sets
            ``num_clients``.
        **kwargs: must contain ``client_cstr`` (client class),
            ``server_side_criterion``, ``global_trainset``, ``global_testset``
            and ``server_side_client_config``. All of ``kwargs`` is also
            forwarded to the fake client's constructor.
    """

    def __init__(self, server_config, clients_dict, **kwargs):
        autoassign(locals())
        self.server_model_state_dict = None
        self.server_model_state_dict_best_so_far = None
        self.num_clients = len(self.clients_dict)
        self.strategy = None
        self.average_train_loss_dict = {}
        self.average_train_acc_dict = {}
        # global model performance
        self.gfl_test_loss_dict = {}
        self.gfl_test_acc_dict = {}
        # local model performance (averaged across all clients)
        self.average_pfl_test_loss_dict = {}
        self.average_pfl_test_acc_dict = {}
        self.active_clients_indicies = None
        self.rounds = 0
        # Fake client on the server side, used to test the global model.
        self.server_side_client = kwargs["client_cstr"](
            kwargs["server_side_criterion"],
            kwargs["global_trainset"],
            kwargs["global_testset"],
            kwargs["server_side_client_config"],
            -1,
            device=fake_client_pyu,
            **kwargs,
        )

    def select_clients(self, ratio):
        """Sample distinct client indices uniformly at random, without replacement.

        Args:
            ratio: fraction of ``self.num_clients`` to select; must be > 0.

        Returns:
            numpy array of distinct indices in ``range(self.num_clients)``.
            It has ``floor(ratio * num_clients)`` entries, clamped to at least
            1 and at most ``num_clients``.
        """
        assert (
            ratio > 0.0
        ), "Invalid ratio. Possibly the server_config['participate_ratio'] is wrong."
        # The tiny epsilon stops float error from dropping a client,
        # e.g. 0.29 * 100 == 28.999999999999996 would otherwise truncate to 28.
        num_selected = int(ratio * self.num_clients + 1e-9)
        # Never run a round with zero participants, and never request more
        # clients than exist (np.random.choice raises with replace=False).
        num_selected = min(max(num_selected, 1), self.num_clients)
        selected_indices = np.random.choice(
            range(self.num_clients), num_selected, replace=False
        )
        return selected_indices

    def testing(self, round, active_only, **kwargs):
        raise NotImplementedError

    def collect_stats(self, stage, round, active_only, **kwargs):
        raise NotImplementedError()

    def aggregate(self, client_uploads, round):
        raise NotImplementedError

    def run(self):
        raise NotImplementedError

    def save(self, filename, keep_clients_model=False):
        """Pickle the server to ``filename`` after dropping heavy references.

        The fake client's data and loaders are always cleared. Client models,
        training data and new state dicts are also cleared unless
        ``keep_clients_model`` is True. This mutates the live objects.
        """
        if not keep_clients_model:
            for client in self.clients_dict.values():
                client.model = None
                client.trainloader = None
                client.trainset = None
                client.new_state_dict = None
        self.server_side_client.trainloader = None
        self.server_side_client.trainset = None
        self.server_side_client.testloader = None
        self.server_side_client.testset = None
        save_to_pkl(self, filename)

    def summary_setup(self):
        """Print a summary of the dataset, server and client settings.

        NOTE(review): the client section reads the module-level
        ``client_config`` global rather than an attribute of the server.
        ``server_config`` must also contain a ``drop_ratio`` key.
        """
        info = "=" * 30 + "Run Summary" + "=" * 30
        info += "\nDataset:\n"
        info += f" dataset:{self.server_config['dataset']} | num_classes:{self.server_config['num_classes']}"
        partition = self.server_config["partition"]
        info += f" | partition:{self.server_config['partition']}"
        if partition == "iid-equal-size":
            info += "\n"
        elif partition in ["iid-diff-size", "noniid-label-distribution"]:
            info += f" | beta:{self.server_config['beta']}\n"
        elif partition == "noniid-label-quantity":
            info += f" | num_classes_per_client:{self.server_config['num_classes_per_client']}\n "
        else:
            if "shards" in partition.split("-"):
                pass
            else:
                raise ValueError(f" Invalid dataset partition strategy:{partition}!")
        info += "Server Info:\n"
        info += f" strategy:{self.server_config['strategy']} | num_clients:{self.server_config['num_clients']} | num_rounds: {self.server_config['num_rounds']}"
        info += f" | participate_ratio:{self.server_config['participate_ratio']} | drop_ratio:{self.server_config['drop_ratio']}\n"
        info += f"Clients Info:\n"
        info += f" model:{client_config['model']} | num_epochs:{client_config['num_epochs']} | batch_size:{client_config['batch_size']}"
        info += f" | optimizer:{client_config['optimizer']} | inint lr:{client_config['client_lr']} | lr scheduler:{client_config['client_lr_scheduler']} | momentum: {client_config['sgd_momentum']} | weight decay: {client_config['sgd_weight_decay']}"
        print(info)

    def summary_result(self):
        raise NotImplementedError

分类任务模型架构，对模型架构进行了解耦。

In [None]:
class Model(nn.Module):
    """Base class for the classification models in this file.

    Args:
        config: model configuration dict, stored as ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def get_params(self):
        """Return ``self.state_dict()``."""
        return self.state_dict()

    def get_parameters(self):
        """Return the iterator from ``self.parameters()``."""
        return self.parameters()

    def get_gradients(self, dataloader):
        """Subclasses must implement gradient computation over ``dataloader``."""
        raise NotImplementedError

    def set_params(self, model_state_dict, exclude_keys=None):
        """Copy tensors from ``model_state_dict`` into this model in place.

        Copies go through the tensors returned by ``state_dict()``, which
        share storage with the module, so the module itself is updated.

        Args:
            model_state_dict: mapping of parameter/buffer name -> tensor.
            exclude_keys: names to skip (default: none).

        Raises:
            KeyError: if ``model_state_dict`` has a key this model lacks.

        Reference: Be careful with the state_dict[key].
        https://discuss.pytorch.org/t/how-to-copy-a-modified-state-dict-into-a-models-state-dict/64828/4.
        """
        if exclude_keys is None:
            exclude_keys = set()
        # Build the state dict once instead of once per key.
        own_state = self.state_dict()
        with torch.no_grad():
            for key, value in model_state_dict.items():
                if key not in exclude_keys:
                    own_state[key].copy_(value)


class ModelWrapper(Model):
    """Compose a feature extractor (``base``) and a classifier (``head``).

    Args:
        base: nn.Module mapping inputs to feature embeddings.
        head: nn.Module mapping embeddings to outputs.
        config: model configuration dict (see ``Model``).
    """

    def __init__(self, base, head, config):
        super(ModelWrapper, self).__init__(config)

        self.base = base
        self.head = head

    def forward(self, x, return_embedding=False):
        """Run ``head(base(x))``.

        Args:
            x: model input.
            return_embedding: if True, return ``(embedding, output)``;
                otherwise only ``output``. Defaults to False so that a plain
                ``model(x)`` call works.
        """
        feature_embedding = self.base(x)
        out = self.head(feature_embedding)
        if return_embedding:
            return feature_embedding, out
        return out


class Conv2Cifar(Model):
    """Two-conv CNN for 3-channel images with a bias-free linear classifier.

    The flatten to ``64 * 5 * 5`` features matches 32x32 inputs (e.g. CIFAR):
    32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5.

    Args:
        config: dict; ``config["num_classes"]`` sets the number of logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.linear1 = nn.Linear(64 * 5 * 5, 384)
        self.linear2 = nn.Linear(384, 192)
        # intentionally remove the bias term for the last linear layer for fair comparison
        self.prototype = nn.Linear(192, config["num_classes"], bias=False)

    def _embed(self, x):
        """Return the 192-d post-ReLU feature embedding used by both public methods."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))

    def forward(self, x):
        """Return the class logits for ``x``."""
        return self.prototype(self._embed(x))

    def get_embedding(self, x):
        """Return ``(embedding, logits)`` for ``x``."""
        embedding = self._embed(x)
        return embedding, self.prototype(embedding)


class Conv2CifarNH(Model):
    """Two-conv CNN with a cosine-similarity (prototype) classifier, as in FedNH.

    Each logit is ``scaling * <unit(embedding), unit(prototype[c])>``. The
    prototypes are L2-normalised on the fly only while they are trainable
    (``requires_grad``); frozen prototypes are used as stored.

    Args:
        config: dict; reads ``num_classes`` (rows of ``prototype``) and
            ``FedNH_return_embedding`` (whether ``forward`` also returns the
            normalised embedding).
    """

    def __init__(self, config):
        super().__init__(config)
        self.return_embedding = config["FedNH_return_embedding"]
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.linear1 = nn.Linear(64 * 5 * 5, 384)
        self.linear2 = nn.Linear(384, 192)
        # Borrow nn.Linear's default init for a (num_classes, 192) prototype matrix.
        init_weight = nn.Linear(192, config["num_classes"], bias=False).state_dict()[
            "weight"
        ]
        self.prototype = nn.Parameter(init_weight)
        self.scaling = torch.nn.Parameter(torch.tensor([1.0]))

    def forward(self, x):
        hidden = self.pool(F.relu(self.conv1(x)))
        hidden = self.pool(F.relu(self.conv2(hidden)))
        hidden = F.relu(self.linear1(hidden.view(-1, 64 * 5 * 5)))
        # Unit-length embeddings; the norm is clamped at 1e-12 to avoid 0/0.
        embedding = F.normalize(F.relu(self.linear2(hidden)), p=2, dim=1, eps=1e-12)
        if self.prototype.requires_grad:
            prototypes = F.normalize(self.prototype, p=2, dim=1, eps=1e-12)
        else:
            prototypes = self.prototype
        logits = self.scaling * (embedding @ prototypes.T)
        return (embedding, logits) if self.return_embedding else logits

在聚类采样策略的方法上，我们的方法FedPCG结合了模型参数解耦策略，实现了基于head层代表性梯度的聚类采样，有效降低了计算复杂度。reference：Fraboni, Y., Vidal, R., Kameni, L., & Lorenzi, M. (2021, July). Clustered sampling: Low-variance and improved representativity for clients selection in federated learning. In International Conference on Machine Learning (pp. 3407-3416). PMLR.

In [None]:
class ClientSelection:
    """Base class for client-selection strategies.

    Args:
        total: number of clients; length of the 0/1 row written by
            ``save_selected_clients``.
        device: stored as ``self.device`` for subclasses (default ``"cpu"``).
    """

    def __init__(self, total, device="cpu"):
        self.total, self.device = total, device

    def select(self, n, client_idxs, metric):
        """Return the selected client indices; the base class returns None."""
        return None

    def save_selected_clients(self, client_idxs, results):
        """Append one comma-separated 0/1 indicator row marking ``client_idxs``.

        ``results`` must be a real file object, because numpy's ``tofile``
        writes through its file descriptor.
        """
        indicator = np.zeros(self.total)
        indicator[client_idxs] = 1
        indicator.tofile(results, sep=",")
        results.write("\n")

    def save_results(self, arr, results, prefix=""):
        """Append ``prefix`` plus ``arr`` (as float32) as a comma-separated row."""
        results.write(prefix)
        values = np.array(arr).astype(np.float32)
        values.tofile(results, sep=",")
        results.write("\n")


"""Clustered Sampling Algorithm 1"""


class ClusteredSampling1(ClientSelection):
    """Clustered sampling by client sample size (Fraboni et al., 2021, Algorithm 1).

    Each client's sampling weight ``n_i / sum(n)`` is spread over
    ``n_cluster`` clusters, each holding exactly ``1 / n_cluster`` of the
    total mass. ``select`` then draws one client from each cluster.

    Args:
        total: number of clients (columns of ``distri_clusters``).
        device: forwarded to ``ClientSelection``.
        n_cluster: number of clusters, i.e. clients drawn per round.
    """

    def __init__(self, total, device, n_cluster):
        super().__init__(total, device)
        self.n_cluster = n_cluster

    def setup(self, n_samples):
        """Compute ``self.distri_clusters`` of shape (n_cluster, total).

        Row k is cluster k's sampling distribution over clients.

        Since clustering is performed according to the clients sample size n_i,
        unless n_i changes during the learning process,
        Algo 1 needs to be run only once at the beginning of the learning process.

        Args:
            n_samples: mapping ``client_id -> number of local samples``.
                Columns follow the sorted client ids, which presumably are
                ``0..total-1``; verify against the caller.
        """
        epsilon = int(10**10)
        client_ids = sorted(n_samples.keys())
        n_samples = np.array([n_samples[i] for i in client_ids])
        weights = n_samples / np.sum(n_samples)

        # Work in exact int64 units so that every cluster holds exactly
        # `epsilon` units. Mixing floats with an int array truncated values on
        # assignment, so cluster totals drifted and the cursor could run past
        # the last row. int64 is explicit because `int` is 32-bit on some
        # platforms and 1e10 would overflow.
        raw_weights = weights * self.n_cluster * epsilon
        augmented_weights = np.floor(raw_weights).astype(np.int64)
        deficit = int(self.n_cluster * epsilon - augmented_weights.sum())
        if deficit > 0:
            # Largest-remainder rounding: give the units lost to flooring back
            # to the clients that lost the most.
            remainders = raw_weights - augmented_weights
            augmented_weights[np.argsort(-remainders, kind="stable")[:deficit]] += 1
        # Assign the heaviest clients first.
        ordered_client_idx = np.flip(np.argsort(raw_weights))

        distri_clusters = np.zeros((self.n_cluster, self.total), dtype=np.int64)
        k = 0
        filled = 0  # units already placed in cluster k
        for client_idx in ordered_client_idx:
            remaining = int(augmented_weights[client_idx])
            while remaining > 0:
                u_i = min(epsilon - filled, remaining)
                distri_clusters[k, client_idx] += u_i
                remaining -= u_i
                filled += u_i
                if filled == epsilon:
                    k += 1
                    filled = 0

        distri_clusters = distri_clusters.astype(float)
        distri_clusters /= distri_clusters.sum(axis=1, keepdims=True)

        self.distri_clusters = distri_clusters

    def select(self, n, client_idxs, metric=None):
        """Draw one client from each of the first ``n`` clusters.

        Each cluster's weights are restricted to ``client_idxs`` and
        renormalised before sampling.

        Returns:
            numpy int array of length ``n`` (duplicates are possible).
        """
        selected_client_idxs = []
        for k in range(n):
            weight = np.take(self.distri_clusters[k], client_idxs)
            # Scalar draw (no `size`): same RNG consumption as a size-1 draw,
            # without converting a 1-element array to int (deprecated).
            selected_client_idxs.append(
                int(np.random.choice(client_idxs, p=weight / sum(weight)))
            )
        return np.array(selected_client_idxs)


"""Clustered Sampling Algorithm 2"""


class ClusteredSampling2(ClientSelection):
    """Clustered sampling by gradient similarity (Fraboni et al., 2021, Algorithm 2).

    Every client keeps a "representative gradient": its local model minus the
    reference global model. Clients are clustered hierarchically (Ward
    linkage) on the pairwise distances between these gradients; the clusters
    are then balanced by sample-size weight, and one client is drawn from
    each of ``n`` clusters.

    Args:
        total: total number of clients (forwarded to ``ClientSelection``).
        device: device the representative gradients are moved to.
        dist: distance between gradients: ``"L1"``, ``"L2"`` (squared
            Euclidean) or ``"cosine"`` (angle in radians).
    """

    def __init__(self, total, device, dist):
        super().__init__(total, device)
        self.distance_type = dist

    def setup(self, n_samples):
        """Store per-client weights ``n_i / sum(n)``, ordered by sorted client id.

        Args:
            n_samples: mapping ``client_id -> number of local samples``.
        """
        client_ids = sorted(n_samples.keys())
        n_samples = np.array([n_samples[i] for i in client_ids])
        print("n_samples", n_samples)
        self.weights = n_samples / sum(n_samples)

    def init(self, global_m, local_models):
        """Record the reference global model and every client's initial gradient."""
        self.prev_global_m = global_m
        self.gradients = self.get_gradients(global_m, local_models)

    def select(self, n, client_idxs, metric=None):
        """Cluster clients by gradient distance and draw one client per cluster.

        Returns:
            numpy int array of length ``n``. ``client_idxs`` must be aligned
            with the order of ``self.gradients``/``self.weights``.
        """
        # GET THE CLIENTS' SIMILARITY MATRIX
        sim_matrix = self.get_matrix_similarity_from_grads(
            self.gradients, distance_type=self.distance_type
        )
        # NOTE(review): scipy treats a square 2-D input as observation vectors
        # (each row = a client's distance profile), not as a distance matrix.
        # This matches the reference implementation; confirm it is intended.
        linkage_matrix = linkage(sim_matrix, "ward")

        distri_clusters = self.get_clusters_with_alg2(linkage_matrix, n, self.weights)
        # sample clients
        selected_client_idxs = np.zeros(n, dtype=int)
        for k in range(n):
            # Scalar draw: same RNG consumption as a size-1 draw, without
            # converting a 1-element array to int (deprecated in NumPy).
            selected_client_idxs[k] = int(
                np.random.choice(client_idxs, p=distri_clusters[k])
            )

        return selected_client_idxs

    def update(self, clients_models, sampled_clients_for_grad, global_m=None):
        """Refresh the representative gradients of the sampled clients.

        Args:
            clients_models: local models of the sampled clients.
            sampled_clients_for_grad: their indices into ``self.gradients``,
                in the same order as ``clients_models``.
            global_m: optional new reference global model. When given, it
                replaces the stored one before the gradients are computed.
        """
        print(">> update gradients")
        if global_m is not None:
            self.prev_global_m = global_m
        # UPDATE THE HISTORY OF LATEST GRADIENT
        gradients_i = self.get_gradients(self.prev_global_m, clients_models)
        for idx, gradient in zip(sampled_clients_for_grad, gradients_i):
            self.gradients[idx] = gradient

    def get_gradients(self, global_m, local_models):
        """Return each client's representative gradient.

        The gradient is the difference between the local model and the
        global model it was sent. Only the first tensor is compared: the
        first entry of ``parameters()`` for each local model and the first
        value of the global state dict. Both are revealed with ``sf.reveal``.
        The difference is returned as a list of per-row tensors.

        NOTE(review): for the models defined in this file the first
        registered parameter is the first conv layer's weight, not the head;
        confirm this is the layer intended for the "head gradient".
        """
        global_model_params = (
            list(sf.reveal(global_m).values())[0].detach().to(self.device)
        )

        local_model_grads = []
        for model in local_models:
            local_params = (
                list(sf.reveal(model).parameters())[0].detach().to(self.device)
            )
            local_model_grads.append(
                [
                    local_weights - global_weights
                    for local_weights, global_weights in zip(
                        local_params, global_model_params
                    )
                ]
            )

        return local_model_grads

    def get_matrix_similarity_from_grads(self, local_model_grads, distance_type):
        """Return the (n_clients, n_clients) torch matrix of pairwise distances.

        Every supported distance is symmetric, so only the upper triangle
        (including the diagonal) is computed and then mirrored.
        """
        n_clients = len(local_model_grads)
        metric_matrix = torch.zeros((n_clients, n_clients))
        pairs = [(i, j) for i in range(n_clients) for j in range(i, n_clients)]
        for i, j in tqdm(pairs, desc=">> similarity", ncols=80):
            distance = self.get_similarity(
                local_model_grads[i], local_model_grads[j], distance_type
            )
            metric_matrix[i, j] = distance
            metric_matrix[j, i] = distance

        return metric_matrix

    def get_similarity(self, grad_1, grad_2, distance_type="L1"):
        """Distance between two representative gradients (lists of tensors).

        Args:
            grad_1, grad_2: lists of tensors, compared pairwise.
            distance_type: ``"L1"`` (sum of absolute differences), ``"L2"``
                (sum of squared differences) or ``"cosine"`` (angle between
                the gradients in radians; 0.0 if either is all zeros).

        Returns:
            The distance as a Python float.

        Raises:
            ValueError: for an unknown ``distance_type``.
        """
        if distance_type == "L1":
            total = 0
            for g_1, g_2 in zip(grad_1, grad_2):
                total = total + torch.sum(torch.abs(g_1 - g_2))
            return float(total)

        if distance_type == "L2":
            total = 0
            for g_1, g_2 in zip(grad_1, grad_2):
                total = total + torch.sum((g_1 - g_2) ** 2)
            return float(total)

        if distance_type == "cosine":
            dot, sq_norm_1, sq_norm_2 = 0.0, 0.0, 0.0
            for g_1, g_2 in zip(grad_1, grad_2):
                dot += float(torch.sum(torch.mul(g_1.squeeze(), g_2.squeeze())))
                sq_norm_1 += float(torch.sum(g_1**2))
                sq_norm_2 += float(torch.sum(g_2**2))

            if sq_norm_1 == 0.0 or sq_norm_2 == 0.0:
                return 0.0
            cosine = dot / np.sqrt(sq_norm_1 * sq_norm_2)
            # Rounding can push |cosine| slightly above 1, where arccos is NaN.
            return float(np.arccos(np.clip(cosine, -1.0, 1.0)))

        raise ValueError(f"Unknown distance_type: {distance_type}")

    def get_clusters_with_alg2(
        self, linkage_matrix: np.array, n_sampled: int, weights: np.array
    ):
        """Algorithm 2: turn a linkage tree into ``n_sampled`` sampling distributions.

        The linkage heights are replaced by the total weight of each merged
        node. The tree is then cut so that every cluster weighs at most
        ``1 / n_sampled``. The ``n_sampled`` heaviest clusters each get their
        own row, and the clients of the remaining clusters are spread over the
        rows until every row holds ``1 / n_sampled`` of the mass.

        Returns:
            float array (n_sampled, n_clients); each row sums to 1.
        """
        epsilon = int(10**10)

        # associate each client to a cluster
        link_matrix_p = deepcopy(linkage_matrix)
        # Weights of the leaves followed by each merged node, in linkage order
        # (scipy numbers the node created at step i as n_leaves + i).
        # sf.reveal returns plain (non-device) values unchanged, so it is
        # called once instead of twice per merge.
        augmented_weights = list(sf.reveal(deepcopy(weights)))

        for i in range(len(link_matrix_p)):
            idx_1, idx_2 = int(link_matrix_p[i, 0]), int(link_matrix_p[i, 1])
            new_weight = augmented_weights[idx_1] + augmented_weights[idx_2]
            augmented_weights.append(new_weight)
            link_matrix_p[i, 2] = int(new_weight * epsilon)

        clusters = fcluster(
            link_matrix_p, int(epsilon / n_sampled), criterion="distance"
        )

        n_clients, n_clusters = len(clusters), len(set(clusters))

        # Associate each cluster label to its total weight (in 1/epsilon units)
        pop_clusters = np.zeros((n_clusters, 2), dtype=np.int64)
        for i in range(n_clusters):
            pop_clusters[i, 0] = i + 1
            for client in np.where(clusters == i + 1)[0]:
                pop_clusters[i, 1] += int(weights[client] * epsilon * n_sampled)

        pop_clusters = pop_clusters[pop_clusters[:, 1].argsort()]

        distri_clusters = np.zeros((n_sampled, n_clients), dtype=np.int64)

        # n_sampled biggest clusters that will remain unchanged
        kept_clusters = pop_clusters[n_clusters - n_sampled :, 0]

        for idx, cluster in enumerate(kept_clusters):
            for client in np.where(clusters == cluster)[0]:
                distri_clusters[idx, client] = int(
                    weights[client] * n_sampled * epsilon
                )

        # Running row totals replace recomputing np.sum(row) on every step.
        row_sums = distri_clusters.sum(axis=1)
        k = 0
        for j in pop_clusters[: n_clusters - n_sampled, 0]:
            clients_in_j = np.where(clusters == j)[0]
            np.random.shuffle(clients_in_j)

            for client in clients_in_j:
                weight_client = int(weights[client] * epsilon * n_sampled)

                while weight_client > 0:
                    u_i = min(epsilon - int(row_sums[k]), weight_client)
                    distri_clusters[k, client] += u_i
                    weight_client -= u_i
                    row_sums[k] += u_i
                    if row_sums[k] == epsilon:
                        k += 1

        distri_clusters = distri_clusters.astype(float)
        print(distri_clusters.shape)
        for l in range(n_sampled):
            distri_clusters[l] /= np.sum(distri_clusters[l])

        return distri_clusters

继承自FedNH中的一些工具函数

In [None]:
def autoassign(lcls):
    """Copy a constructor's local variables onto ``self`` as attributes.

    Intended use is ``autoassign(locals())`` as the first statement of an
    ``__init__``. Every entry of ``lcls`` other than ``"self"`` and
    ``"kwargs"`` becomes ``self.<name> = value``.

    Reference: https://stackoverflow.com/questions/3652851/what-is-the-best-way-to-do-automatic-attribute-assignment-in-python-and-is-it-a

    NOTE(review): the ``"kwargs"`` entry is flattened (each key copied onto
    ``self``) only when ``self`` already has an attribute named ``kwargs``.
    ``lcls["kwargs"]`` itself is never read, so with a plain ``**kwargs``
    constructor the extra keyword arguments are dropped. Confirm whether this
    is intended before relying on kwargs-derived attributes.
    """
    for name, value in lcls.items():
        if name == "self":
            continue
        instance_attrs = lcls["self"].__dict__
        if name != "kwargs":
            instance_attrs[name] = value
        elif name in instance_attrs:
            stored = instance_attrs[name]
            for kw in stored:
                instance_attrs[kw] = instance_attrs[name][kw]


def calculate_model_size(model_state_dict):
    """Return the storage taken by the tensors of a state dict, in MB (10**6 bytes)."""
    total_bytes = 0
    for tensor in model_state_dict.values():
        total_bytes += tensor.nelement() * tensor.element_size()
    return total_bytes * 1e-6


def calculate_flops(model, inputs_size, device):
    """Return the multiply-add count that ``summary`` reports for ``model``.

    Args:
        model: the network to profile.
        inputs_size: input size for a batch of 1, passed to ``summary``.
        device: device forwarded to ``summary``.

    NOTE(review): ``summary`` is not imported anywhere in this file, so this
    raises NameError unless it is brought into scope elsewhere. The
    ``total_mult_adds`` attribute suggests ``torchinfo.summary``; confirm
    this and add the import.
    """
    stat = summary(model, inputs_size, verbose=0, device=device)
    return stat.total_mult_adds


def save_to_pkl(obj, path):
    """Pickle ``obj`` to ``path`` (highest protocol), overwriting any existing file."""
    with open(path, mode="wb") as fh:
        pickle.dump(obj, fh, protocol=pickle.HIGHEST_PROTOCOL)


def load_from_pkl(path):
    """Load and return the object pickled at ``path``.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(path, mode="rb") as fh:
        return pickle.load(fh)


def mkdirs(dirpath):
    """Create ``dirpath`` and any missing parents, unless the path already exists."""
    if os.path.exists(dirpath):
        return
    # exist_ok tolerates another thread/process creating it in the meantime.
    os.makedirs(dirpath, exist_ok=True)


def access_last_added_element(ordered_dict):
    """Return the value stored under the most recently inserted key, or ``None`` if empty.

    Relies on ``reversed(mapping)`` yielding keys from newest to oldest.
    """
    for newest_key in reversed(ordered_dict):
        return ordered_dict[newest_key]
    return None


class Initializer:
    """
    Apply an in-place weight-initialization function to every Conv2d / Linear layer.

    ref:
    1. https://github.com/3ammor/Weights-Initializer-pytorch/blob/master/weight_initializer.py
    2. https://github.com/kevinzakka/pytorch-goodies
    """

    def __init__(self):
        pass

    @staticmethod
    def initialize(model, initialization, **kwargs):
        """Initialize ``model``'s Conv2d/Linear weights in place (biases best-effort).

        model: ``nn.Module``; every submodule is visited through ``model.apply``.
        initialization: in-place init function, e.g. ``nn.init.xavier_uniform_``.
        **kwargs: forwarded to ``initialization`` for the weights only.
        """

        def weights_init(m):
            if not isinstance(m, (nn.Conv2d, nn.Linear)):
                return
            initialization(m.weight.data, **kwargs)
            if m.bias is None:
                return
            try:
                # The bias is initialized without **kwargs, and fan-based schemes reject
                # 1-D tensors, so failures are deliberately ignored and the layer keeps
                # its default bias. Only Exception is caught so Ctrl-C still works.
                initialization(m.bias.data)
            except Exception:
                pass

        model.apply(weights_init)


"""
Split Datasets

References:
1. https://github.com/Xtra-Computing/NIID-Bench
"""


class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    ``targets`` holds the labels of the selected samples, obtained by indexing
    ``dataset.targets`` with the (int-converted) index list.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(map(int, idxs))
        self.targets = dataset.targets[self.idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        sample, target = self.dataset[self.idxs[item]]
        return sample, target


def remove_by_class(trainset, list_of_classes_to_remove):
    """Drop every sample whose label is in ``list_of_classes_to_remove`` (in place).

    ``trainset`` is a ``DatasetSplit``-like object; its ``idxs`` and ``targets`` are
    rewritten after each removed class. The same object is returned.
    """
    for label in list_of_classes_to_remove:
        keep_mask = trainset.targets != label
        trainset.idxs = [idx for idx, keep in zip(trainset.idxs, keep_mask) if keep]
        trainset.targets = trainset.dataset.targets[trainset.idxs]
    return trainset


def split_trainset_by_class(client_trainset):
    """
    Group a client's samples by label.

    Input: client_trainset, an object of the DatasetSplit class.
    Return: dict mapping each label present in ``client_trainset.targets`` (as a Python
        scalar, via ``tolist``) to a DatasetSplit over the same underlying dataset holding
        only that label's samples, in their original order.
    """
    labels_present = torch.unique(client_trainset.targets).tolist()
    return {
        label: DatasetSplit(
            client_trainset.dataset,
            list(compress(client_trainset.idxs, client_trainset.targets == label)),
        )
        for label in labels_present
    }


def sampler(dataset, num_clients, partition, seed=None, minsize=10, **kwargs):
    """
    Partition ``dataset`` into ``num_clients`` client sub-datasets.

    dataset: torch.utils.data.Dataset object with a ``targets`` attribute (a tensor for the
        label-based partitions).
    num_clients: number of client sub-datasets to create.
    seed: if not None, passed to ``np.random.seed`` before sampling.
    minsize: minimum number of samples per client that the Dirichlet-based partitions try to reach.
    partition:
        iid-equal-size:
            uniformly randomly sample from the whole datasets and each party approximately has the same number of samples.

        iid-diff-size:
            uniformly randomly sample from the whole datasets but each party approximately has different number of samples;
            should also set the beta parameter.

        noniid-label-quantity:
            Each client will only contain `num_classes_per_client` classes of samples; for any two clients that have the same class, the samples will not overlap;
            should also set the num_classes, num_classes_per_client, and ylabels parameters.
            samples in each classes are uniformly divided. But this could still lead to class imbalance in each client.
            See https://arxiv.org/pdf/2102.02079.pdf a) Quantity-based label imbalance

        noniid-label-distribution:
            The class proportions of each client follow a dirichlet distribution with concentration parameter beta.
            should also set the num_classes, beta, and ylabels parameters.

        shards:
            Suppose there are (n clients, c classes, N datapoints) and each clients own s shards of data. Then the
            number of data per shard is  size_s = N / (n * s). And each class is split into c/size_s shards.
            Shards are randomly assigned to clients. Requires num_shards_per_client.

    kwargs:
        beta: concentration parameter for the **symmetric** Dirichlet distribution; a positive number, or a string
            such as "0.5" / "0.5b" where a trailing "b" selects the balanced variant of noniid-label-distribution.
        ylabels: 1d tensor of size as the len(dataset)
        num_classes: int; larger than 0
        num_classes_per_client: int;  larger than 0 smaller than total number of classes in ylabels
        num_shards_per_client: int; required for the shards partition
        assigned_clients_per_class / assigned_classes_per_client: optional user-supplied class
            assignment for noniid-label-quantity (a list of counts / a list of sets)
        server_config: dict with 'dataset' and 'num_clients'; required by the balanced variant
            (used to name the cached partition file)

    Return: a tuple (dataset_per_client_dict, stats_dict); dataset_per_client_dict is
        {cid: DatasetSplit} and stats_dict records what ``sampler_reuse`` needs to rebuild the split.

    Raises:
        ValueError: unknown partition, a missing required kwarg, or an unparsable beta.
        TypeError: dataset.targets is not a tensor (label-based partitions).
        RuntimeError: iid-diff-size could not reach ``minsize`` within the allowed attempts.

    --- Notes ---
    Effect of the beta parameter:
        When beta = 1, the symmetric Dirichlet distribution is equivalent to a uniform distribution over the open standard (K − 1)-simplex, (the distribution over distributions is uniform)
        When beta > 1, it prefers variates that are dense, evenly distributed distributions, i.e. all the values within a single sample are similar to each other.
        When beta < 1, it prefers sparse distributions, i.e. most of the values within a single sample will be close to 0, and the vast majority of the mass will be concentrated in a few of the values.

    --- References ---
    1. https://en.wikipedia.org/wiki/Dirichlet_distribution#The_concentration_parameter
    """
    # process arguments
    # Only the balanced variant of noniid-label-distribution switches this on.
    is_balanced = False
    if partition in ["iid-diff-size", "noniid-label-distribution"]:
        if "beta" not in kwargs:
            beta = 0.5
            warnings.warn(f"partition:{partition} | beta is not provided. Set to 0.5.")
        else:
            # beta may be a number or a string; "<float>b" requests balanced sampling.
            temp = str(kwargs["beta"]).split("b")
            if len(temp) not in (1, 2):
                raise ValueError(
                    f"partition:{partition} | cannot parse beta={kwargs['beta']!r}."
                )
            beta = float(temp[0])
            is_balanced = len(temp) == 2
        assert beta > 0, "beta needs to be positive"
    if partition == "shards" and "num_shards_per_client" not in kwargs:
        raise ValueError(
            f"The num_shards_per_client parameter needs to be set for the partition {partition}."
        )

    if partition in ["noniid-label-quantity", "noniid-label-distribution"]:
        if "num_classes" not in kwargs:
            raise ValueError(
                f"The num_classes parameter needs to be set for the partition {partition}."
            )
        num_classes = kwargs["num_classes"]
        try:
            num_unique_class = len(torch.unique(dataset.targets))
        except TypeError as err:
            raise TypeError(
                "dataset.targets is not of tensor type! Proper actions are required."
            ) from err
        assert (
            num_classes == num_unique_class
        ), f"num_classes is set to {num_classes}, but number of unique class detected in ylables are {num_unique_class}."
        if "ylabels" not in kwargs:
            raise ValueError(
                f"The ylabels parameter needs to be set for the partition {partition}."
            )
        ylabels = kwargs["ylabels"]
    if seed is not None:
        np.random.seed(seed)
    num_samples = len(dataset)
    idxs = np.random.permutation(num_samples)
    cur_minsize = 0
    attemp = 0
    max_attemp = 3
    stats_dict = {}
    if partition == "iid-equal-size":
        batch_idxs = np.array_split(idxs, num_clients)
        if len(batch_idxs[-1]) < minsize:
            warnings.warn(
                f"partition:{partition} | Some clients have less than {minsize} samples. Check it before continue."
            )
        cid_idxlst_dict = {cid: batch_idxs[cid].tolist() for cid in range(num_clients)}
    elif partition == "iid-diff-size":
        # The number of samples per client follows a dirichlet distribution with
        # concentration parameter beta; classes within a client stay roughly balanced.
        proportions = None
        # Always draw at least once (minsize <= 0 previously left `proportions` undefined).
        while proportions is None or cur_minsize < minsize:
            attemp += 1
            # Allow exactly `max_attemp` draws before giving up.
            if attemp > max_attemp:
                raise RuntimeError(
                    f"partition:{partition} | Exceeds max allowed attempts. Consider change the random seed."
                )
            proportions = np.random.dirichlet(np.repeat(beta, num_clients))
            proportions = proportions / proportions.sum()
            cur_minsize = np.min(proportions * len(idxs))

        proportions_to_num = (np.cumsum(proportions) * len(idxs)).astype(int)[:-1]
        batch_idxs = np.split(idxs, proportions_to_num)
        cid_idxlst_dict = {i: batch_idxs[i].tolist() for i in range(num_clients)}
        stats_dict["proportions"] = proportions
    elif partition == "noniid-label-quantity":
        # Each client only holds `num_classes_per_client` classes; clients sharing a
        # class receive disjoint samples of it.
        if (
            "assigned_clients_per_class" in kwargs
            and "assigned_classes_per_client" in kwargs
        ):
            # use user supplied partition
            assigned_clients_per_class = kwargs["assigned_clients_per_class"]
            assigned_classes_per_client = kwargs["assigned_classes_per_client"]
            assert (
                type(assigned_clients_per_class) == list
            ), "assigned_clients_per_class has to a list"
            assert (
                type(assigned_classes_per_client) == list
            ), "assigned_classes_per_client has to a list"
            assert (
                type(assigned_classes_per_client[0]) == set
            ), "the elements of assigned_classes_per_client has to a set"
            cid_idxlst_dict = {cid: [] for cid in range(num_clients)}
            num_classes_per_client = [len(s) for s in assigned_classes_per_client]
        else:
            if "num_classes_per_client" not in kwargs:
                raise ValueError(
                    f"The num_classes_per_client parameter needs to be set for the partition {partition}."
                )
            num_classes_per_client = kwargs["num_classes_per_client"]
            assert (
                num_classes_per_client <= num_classes
            ), "`num_classes_per_client` should be no bigger than `num_classes`"
            cid_idxlst_dict = {cid: [] for cid in range(num_clients)}
            assigned_clients_per_class = [0 for i in range(num_classes)]
            assigned_classes_per_client = []
            for cid in range(num_clients):
                # assign class `class_idx` to client `cid` (round-robin guarantees coverage),
                # then top up with distinct random classes
                class_idx = cid % num_classes
                current = set()
                current.add(class_idx)
                assigned_clients_per_class[class_idx] += 1
                assigned_class_count = 1
                while assigned_class_count < num_classes_per_client:
                    ind = np.random.randint(0, num_classes)
                    if ind not in current:
                        assigned_class_count += 1
                        current.add(ind)
                        assigned_clients_per_class[ind] += 1
                assigned_classes_per_client.append(current)

        missing_classes = []
        for k in range(num_classes):
            if assigned_clients_per_class[k] == 0:
                missing_classes.append(str(k))
        if len(missing_classes) > 0:
            warnings.warn(
                "Classes "
                + ",".join(missing_classes)
                + " are not used. Consider increase either num_clients or num_classes_per_client."
            )
        for k in range(num_classes):
            idx_k = np.where(ylabels == k)[0]
            # Shuffle even for unused classes so the RNG stream does not depend on them.
            np.random.shuffle(idx_k)
            if assigned_clients_per_class[k] <= 0:
                # No client holds class k (np.array_split rejects 0 sections).
                continue
            split = np.array_split(idx_k, assigned_clients_per_class[k])
            ids = 0
            for cid in range(num_clients):
                if k in assigned_classes_per_client[cid]:
                    cid_idxlst_dict[cid] += split[ids].tolist()
                    ids += 1
        stats_dict["num_classes"] = num_classes
        stats_dict["num_classes_per_client"] = num_classes_per_client
        stats_dict["assigned_classes_per_client"] = assigned_classes_per_client
        stats_dict["assigned_clients_per_class"] = assigned_clients_per_class
    elif partition == "noniid-label-distribution":
        # The class proportions of each client follow a dirichlet distribution with
        # concentration parameter beta.
        # feddf: https://github.com/epfml/federated-learning-public-code/blob/7e002ef5ff0d683dba3db48e2d088165499eb0b9/codes/FedDF-code/pcode/datasets/partition_data.py#L197
        if is_balanced:
            # NOTE: reseeds with a fixed value, overriding `seed`; the result is cached on disk.
            np.random.seed(2022)
            server_config = kwargs["server_config"]
            save_dir = f"../experiments/datapartition/{server_config['dataset']}_{beta}b_{server_config['num_clients']}.pkl"
            print("Doing balanced dir sampling")
            if os.path.exists(save_dir):
                print("Partition is found!")
                cid_idxlst_dict = load_from_pkl(save_dir)
            else:
                # Every client gets the same number of samples; each sample's class is drawn
                # from that client's Dirichlet prior.
                n_data_per_clnt = int(num_samples / num_clients)
                clnt_data_list = (np.ones(num_clients) * n_data_per_clnt).astype(int)
                cls_priors = np.random.dirichlet(
                    alpha=[beta] * num_classes, size=num_clients
                )
                prior_cumsum = np.cumsum(cls_priors, axis=1)
                idx_list = [np.where(ylabels == i)[0] for i in range(num_classes)]
                cls_amount = [len(idx_list[i]) for i in range(num_classes)]
                cid_idxlst_dict = {cid: [] for cid in range(num_clients)}
                while np.sum(clnt_data_list) != 0:
                    curr_clnt = np.random.randint(num_clients)
                    # If current node is full resample a client
                    print("Remaining Data: %d" % np.sum(clnt_data_list))
                    if clnt_data_list[curr_clnt] <= 0:
                        continue
                    clnt_data_list[curr_clnt] -= 1
                    curr_prior = prior_cumsum[curr_clnt]
                    while True:
                        cls_label = np.argmax(np.random.uniform() <= curr_prior)
                        # Redraw class label if that class has no samples left
                        if cls_amount[cls_label] <= 0:
                            continue
                        cls_amount[cls_label] -= 1
                        cid_idxlst_dict[curr_clnt].append(
                            idx_list[cls_label][cls_amount[cls_label]]
                        )
                        break
                mkdirs("../experiments/datapartition/")
                save_to_pkl(cid_idxlst_dict, save_dir)
                print("cid_idxlst_dict is saved to", save_dir)
        else:
            resample = False
            batch_idxs = None
            # Always run at least one round (minsize <= 0 previously left batch_idxs undefined).
            while batch_idxs is None or cur_minsize < minsize or resample:
                attemp += 1
                if attemp > max_attemp:
                    count = 0
                    for cid in range(num_clients):
                        if allocated_classes[cid] <= 1:
                            count += 1
                    print(f" Warning: {count} clients have less than 2 classes")
                    break
                batch_idxs = [[] for _ in range(num_clients)]
                allocated_classes = [0] * num_clients
                for k in range(num_classes):
                    idx_k = np.where(ylabels == k)[0]
                    np.random.shuffle(idx_k)
                    # determine the fraction of samples in class k for each client;
                    proportions = np.random.dirichlet(np.repeat(beta, num_clients))
                    # if number of samples in client j is already larger than the threshold num_samples / num_clients
                    # then the client won't contain any new class including the current class k
                    proportions = np.array(
                        [
                            p * (len(allocated_idxs) < num_samples / num_clients)
                            for p, allocated_idxs in zip(proportions, batch_idxs)
                        ]
                    )
                    proportions = proportions / proportions.sum()
                    # Stored by reference: the swaps below are reflected in stats_dict too.
                    stats_dict[f"proportions_{k}"] = proportions
                    proportions_to_num = (np.cumsum(proportions) * len(idx_k)).astype(
                        int
                    )[:-1]
                    # reference: https://numpy.org/doc/stable/reference/generated/numpy.split.html
                    chunks = np.split(idx_k, proportions_to_num)
                    # hack to fix class deficiency in some clients: hand the trailing chunks
                    # to clients that own at most one class so far
                    if k >= 2:
                        if min(allocated_classes) <= 1:
                            cid_has_only_one_or_less_class = []
                            for cid in range(num_clients):
                                if allocated_classes[cid] <= 1:
                                    cid_has_only_one_or_less_class.append(cid)
                            replace_index = -1
                            for cid in cid_has_only_one_or_less_class:
                                temp_chunk = chunks[cid]
                                temp_ratio = proportions[cid]
                                chunks[cid] = chunks[replace_index]
                                chunks[replace_index] = temp_chunk
                                proportions[cid] = proportions[replace_index]
                                proportions[replace_index] = temp_ratio
                                replace_index -= 1
                    cid = 0
                    for allocated_idxs, idx in zip(batch_idxs, chunks):
                        added_samples = idx.tolist()
                        if len(added_samples) > 0:
                            allocated_idxs += added_samples
                            allocated_classes[cid] += 1
                        cid += 1
                    cur_minsize = min(
                        [len(allocated_idxs) for allocated_idxs in batch_idxs]
                    )
                if min(allocated_classes) <= 1:
                    resample = True
                    print(
                        " [Info - Dirichlet Sampling]: At leaset one client only has one class label. Perform Resampling..."
                    )
                else:
                    resample = False
            cid_idxlst_dict = {cid: [] for cid in range(num_clients)}
            for cid in range(num_clients):
                np.random.shuffle(batch_idxs[cid])
                cid_idxlst_dict[cid] = batch_idxs[cid]
        stats_dict["num_classes"] = num_classes
    elif partition == "shards":
        num_shards_per_client = kwargs["num_shards_per_client"]
        dict_users, stats_dict["rand_set_all"] = sshards(
            dataset,
            num_clients,
            num_shards_per_client,
            server_data_ratio=0.0,
            rand_set_all=[],
        )
        cid_idxlst_dict = {i: dict_users[i].tolist() for i in range(num_clients)}
    else:
        raise ValueError(f"partition:{partition} is not recognized.")
    # generate a set of sub-datasets
    dataset_per_client_dict = {
        cid: DatasetSplit(dataset, cid_idxlst_dict[cid]) for cid in range(num_clients)
    }
    stats_dict["num_clients"] = num_clients
    stats_dict["partition"] = partition
    stats_dict["seed"] = seed
    stats_dict["minsize"] = minsize
    return dataset_per_client_dict, stats_dict


def sampler_reuse(dataset, stats_dict, **kwargs):
    """
    Rebuild a client partition of ``dataset`` from the ``stats_dict`` produced by ``sampler``.

    The partition type, number of clients and seed come from ``stats_dict``; for the
    label-based partitions the stored class assignments / per-class proportions are
    re-applied to this dataset's labels.

    dataset: Dataset with a ``targets`` tensor.
    stats_dict: second return value of ``sampler``.
    kwargs:
        ylabels: required for noniid-label-quantity and noniid-label-distribution.

    Return: {cid: DatasetSplit}.
    Raises: ValueError for an unsupported partition or missing ``ylabels``.
    """
    partition = stats_dict["partition"]
    num_clients = stats_dict["num_clients"]
    if stats_dict["seed"] is not None:
        np.random.seed(stats_dict["seed"])
    if partition in ["noniid-label-quantity", "noniid-label-distribution"]:
        num_classes = stats_dict["num_classes"]
        num_unique_class = len(torch.unique(dataset.targets))
        assert (
            num_classes == num_unique_class
        ), f"num_class is set to {num_classes}, but number of unique class detected in ylables are {num_unique_class}. The dataset may have a different distribution!"
        if "ylabels" not in kwargs:
            raise ValueError(
                f"The ylabels parameter needs to be set for the partition {partition}."
            )
        ylabels = kwargs["ylabels"]
    num_samples = len(dataset)
    # Drawn for every partition so the RNG stream matches `sampler`.
    idxs = np.random.permutation(num_samples)
    if partition == "iid-equal-size":
        batch_idxs = np.array_split(idxs, num_clients)
        if len(batch_idxs[-1]) < stats_dict["minsize"]:
            warnings.warn(
                f"partition:{partition} | Some clients have less than {stats_dict['minsize']} samples. Check it before continue."
            )
        cid_idxlst_dict = {cid: batch_idxs[cid].tolist() for cid in range(num_clients)}
    elif partition == "iid-diff-size":
        proportions_to_num = (np.cumsum(stats_dict["proportions"]) * len(idxs)).astype(
            int
        )[:-1]
        batch_idxs = np.split(idxs, proportions_to_num)
        cid_idxlst_dict = {i: batch_idxs[i].tolist() for i in range(num_clients)}
    elif partition == "noniid-label-quantity":
        # An int, or a per-client list when `sampler` got a user-supplied assignment.
        num_classes_per_client = stats_dict["num_classes_per_client"]
        assert (
            np.max(num_classes_per_client) <= num_classes
        ), "`num_classes_per_client` should be no bigger than `num_classes`"
        cid_idxlst_dict = {cid: [] for cid in range(num_clients)}
        for k in range(num_classes):
            idx_k = np.where(ylabels == k)[0]
            np.random.shuffle(idx_k)
            n_assigned = stats_dict["assigned_clients_per_class"][k]
            if n_assigned <= 0:
                # No client holds class k (np.array_split rejects 0 sections).
                continue
            split = np.array_split(idx_k, n_assigned)
            ids = 0
            for cid in range(num_clients):
                if k in stats_dict["assigned_classes_per_client"][cid]:
                    cid_idxlst_dict[cid] += split[ids].tolist()
                    ids += 1
    elif partition == "noniid-label-distribution":
        batch_idxs = [[] for _ in range(num_clients)]
        for k in range(num_classes):
            idx_k = np.where(ylabels == k)[0]
            np.random.shuffle(idx_k)
            proportions = stats_dict[f"proportions_{k}"]
            proportions_to_num = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            batch_idxs = [
                allocated_idxs + idx.tolist()
                for allocated_idxs, idx in zip(
                    batch_idxs, np.split(idx_k, proportions_to_num)
                )
            ]

        cid_idxlst_dict = {cid: [] for cid in range(num_clients)}
        for cid in range(num_clients):
            np.random.shuffle(batch_idxs[cid])
            cid_idxlst_dict[cid] = batch_idxs[cid]
    else:
        raise ValueError(f"partition:{partition} is not recognized.")
    # generate a set of sub-datasets
    dataset_per_client_dict = {
        cid: DatasetSplit(dataset, cid_idxlst_dict[cid]) for cid in range(num_clients)
    }
    return dataset_per_client_dict


def sshards(dataset, num_users, shard_per_user, server_data_ratio, rand_set_all=None):
    """
    Shard-based non-IID split: each class is cut into equal shards and users receive shards.

    dataset: object with a ``targets`` tensor (``.item()`` is called on its elements).
    num_users: number of users/clients.
    shard_per_user: number of shards per user; each user ends up with at most this many classes.
    server_data_ratio: if > 0, also sample this fraction of all indices into
        ``dict_users["server"]`` (as a set).
    rand_set_all: optional array with one row of class labels per user (e.g. the value
        returned by a previous call); ``None`` or empty draws a new one.

    Returns: (dict_users, rand_set_all), where ``dict_users[i]`` is an int64 index array.

    NOTE: calls ``setup_seed(2022)`` on entry (defined elsewhere; presumably reseeds the
    global RNGs, making the split independent of the caller's RNG state).
    """
    setup_seed(2022)
    dict_users = {i: np.array([], dtype="int64") for i in range(num_users)}
    all_idxs = list(range(len(dataset)))

    # collect all data indices of each class, in first-seen label order
    idxs_dict = {}
    for i in range(len(dataset)):
        label = dataset.targets[i].item()
        idxs_dict.setdefault(label, []).append(i)

    num_classes = len(np.unique(dataset.targets))
    shard_per_class = int(shard_per_user * num_users / num_classes)
    for label in idxs_dict.keys():
        x = idxs_dict[label]
        # Cut the class into `shard_per_class` equal shards; the leftover samples are
        # appended one each to the first shards.
        num_leftover = len(x) % shard_per_class
        leftover = x[-num_leftover:] if num_leftover > 0 else []
        x = np.array(x[:-num_leftover]) if num_leftover > 0 else np.array(x)
        x = list(x.reshape((shard_per_class, -1)))

        for i, idx in enumerate(leftover):
            x[i] = np.concatenate([x[i], [idx]])
        idxs_dict[label] = x

    if rand_set_all is None or len(rand_set_all) == 0:
        # Every label appears `shard_per_class` times; shuffle and deal one row per user.
        rand_set_all = list(range(num_classes)) * shard_per_class
        random.shuffle(rand_set_all)
        rand_set_all = np.array(rand_set_all).reshape((num_users, -1))

    # divide and assign: per label in the user's row, take one random remaining shard
    for i in range(num_users):
        rand_set = []
        for label in rand_set_all[i]:
            idx = np.random.choice(len(idxs_dict[label]), replace=False)
            rand_set.append(idxs_dict[label].pop(idx))
        dict_users[i] = np.concatenate(rand_set)

    # sanity checks: class budget per user respected, every sample assigned exactly once
    test = []
    for value in dict_users.values():
        x = np.unique(dataset.targets[value])
        assert (len(x)) <= shard_per_user
        test.append(value)
    test = np.concatenate(test)
    assert len(test) == len(dataset)
    assert len(set(list(test))) == len(dataset)

    if server_data_ratio > 0.0:
        dict_users["server"] = set(
            np.random.choice(
                all_idxs, int(len(dataset) * server_data_ratio), replace=False
            )
        )
    return dict_users, rand_set_all


"""
visualization tools
"""


def visualize_sampling(dataset_per_client_dict, num_classes, figsize=(10, 8), **kwargs):
    """Plot a client x class heatmap of sample counts and return the count matrix.

    dataset_per_client_dict: {cid: DatasetSplit}; the keys are used as row indices of the
        matrix, so they are expected to be 0..num_clients-1 -- TODO confirm for all callers.
    num_classes: number of class columns to count (labels 0..num_classes-1).
    figsize: matplotlib figure size.
    kwargs:
        fig_path_name: if given, save the figure there (creating the parent directory) and
            close it; otherwise ``plt.show()`` is called.

    Returns: np.ndarray of shape (num_clients, num_classes) with per-client class counts.
    """
    num_clients = len(dataset_per_client_dict)
    mat = np.zeros((num_clients, num_classes))
    targets = dataset_per_client_dict[0].dataset.targets
    for key, subset in dataset_per_client_dict.items():
        # Gather this client's labels once instead of once per class.
        client_targets = targets[subset.idxs]
        for k in range(num_classes):
            mat[key, k] = torch.sum(torch.eq(client_targets, k)).item()
    fig, ax = plt.subplots(figsize=figsize)

    im, _ = heatmap(
        mat,
        np.arange(num_clients),
        np.arange(num_classes),
        ax=ax,
        cmap="YlGn",
        cbarlabel="#Samples",
    )
    _ = annotate_heatmap(im, valfmt="{x:.0f}")

    fig.tight_layout()
    if "fig_path_name" in kwargs:
        fig_path_name = kwargs["fig_path_name"]
        # A bare file name has no parent directory to create.
        dirpath = os.path.dirname(fig_path_name)
        if dirpath:
            mkdirs(dirpath)
        plt.savefig(fig_path_name)
        # Release the figure; repeated calls would otherwise accumulate open figures.
        plt.close(fig)
    else:
        plt.show()
    return mat


def heatmap(data, x_labels, y_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs):
    """Draw ``data`` as a heatmap with clients on the x axis and class labels on the y axis.

    data: 2-D array of shape (num_x, num_y); it is drawn transposed (``data.T``), so rows of
        ``data`` become image columns.
    x_labels / y_labels: tick labels for the x (client) and y (class) axes.
    ax: matplotlib Axes to draw on; defaults to ``plt.gca()``.
    cbar_kw: extra keyword arguments for ``colorbar`` (``None`` means none).
    cbarlabel: label of the colour bar.
    **kwargs: forwarded to ``ax.imshow``.

    Returns: (im, cbar) -- the image returned by ``imshow`` and its colour bar.
    """
    # None instead of a shared mutable {} default.
    if cbar_kw is None:
        cbar_kw = {}

    if not ax:
        ax = plt.gca()

    im = ax.imshow(data.T, **kwargs)
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
    ax.set_xticks(np.arange(data.shape[0]))
    ax.set_xticklabels(x_labels)
    ax.set_xlabel("Client ID")
    ax.set_yticks(np.arange(data.shape[1]))
    ax.set_yticklabels(y_labels)
    ax.set_ylabel("Class label")

    # Minor ticks between cells carry the white grid lines separating them.
    ax.set_xticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar


def annotate_heatmap(
    im,
    data=None,
    valfmt="{x:.2f}",
    textcolors=("black", "white"),
    threshold=None,
    **textkw,
):
    """Write every cell value as text on top of an image and return the created texts.

    im: the image to annotate (e.g. the first value returned by ``heatmap``).
    data: values to print; if not a list/ndarray, ``im.get_array()`` is used.
    valfmt: format string (``{x}`` placeholder) or a callable ``(value, pos) -> str``.
    textcolors: (colour at or below threshold, colour above threshold).
    threshold: value in data units; defaults to half of the normalized maximum.
    **textkw: extra keyword arguments for ``im.axes.text``; the colour is always overridden.
    """
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # The threshold is compared in the image's normalized colour space.
    if threshold is None:
        threshold = im.norm(data.max()) / 2.0
    else:
        threshold = im.norm(threshold)

    # Centre the text by default; caller-supplied keywords win.
    text_kwargs = {
        "horizontalalignment": "center",
        "verticalalignment": "center",
        **textkw,
    }

    formatter = valfmt
    if isinstance(formatter, str):
        formatter = matplotlib.ticker.StrMethodFormatter(formatter)

    created = []
    n_rows, n_cols = data.shape[0], data.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            value = data[row, col]
            is_bright = im.norm(value) > threshold
            text_kwargs["color"] = textcolors[int(is_bright)]
            created.append(im.axes.text(col, row, formatter(value, None), **text_kwargs))

    return created


class MulGaussian(Dataset):
    """Synthetic mixture dataset: class ``c`` has ``n_lst[c]`` draws from N(mean_lst[c], I).

    ``data`` stacks all samples class by class; ``targets`` is the matching int64 label tensor.
    """

    def __init__(self, mean_lst, n_lst):
        self.data = None
        self.targets = None
        for label, mean in enumerate(mean_lst):
            count = n_lst[label]
            gaussian = MultivariateNormal(torch.tensor(mean), torch.eye(len(mean)))
            draws = gaussian.sample(sample_shape=(count,))
            tags = torch.full((count,), label, dtype=torch.int32)
            if self.data is None:
                self.data, self.targets = draws, tags
            else:
                self.data = torch.cat((self.data, draws))
                self.targets = torch.cat((self.targets, tags))
        self.targets = self.targets.type(torch.LongTensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        position = idx.tolist() if torch.is_tensor(idx) else idx
        return self.data[position], self.targets[position]


class Spiral(Dataset):
    """Synthetic 2-D spiral dataset with one arm per class.

    With ``K = len(n_lst)`` arms, arm ``c`` has ``n_lst[c]`` points whose radius grows
    linearly from 1 to 10 while the angle sweeps the c-th of K equal slices of [0, 2*pi],
    shifted by uniform noise in [0, sigma). Points are (r*sin(t), r*cos(t)).
    """

    def __init__(self, n_lst, sigma=0.5):
        num_arms = len(n_lst)
        self.data = None
        self.targets = None
        for arm, count in enumerate(n_lst):
            radius = torch.linspace(1, 10, count)
            start_angle = arm / num_arms * 2 * torch.pi
            end_angle = (arm + 1) / num_arms * 2 * torch.pi
            angle = torch.linspace(start_angle, end_angle, count) + torch.rand(count) * sigma
            points = torch.stack((radius * torch.sin(angle), radius * torch.cos(angle)), 1)
            arm_labels = torch.full((count,), arm, dtype=torch.int32)
            if self.data is None:
                self.data, self.targets = points, arm_labels
            else:
                self.data = torch.cat((self.data, points))
                self.targets = torch.cat((self.targets, arm_labels))
        self.targets = self.targets.type(torch.LongTensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        position = idx.tolist() if torch.is_tensor(idx) else idx
        return self.data[position], self.targets[position]


"""
TinyImageNet Dataset
"""
# File extension of TinyImageNet images; used to glob image files and to build train file names.
EXTENSION = "JPEG"
# Images per class in the train split; train labels are keyed as "<class id>_<index>.JPEG".
NUM_IMAGES_PER_CLASS = 500
# File under the dataset root listing the class ids, one per line (sorted to assign label numbers).
CLASS_LIST_FILE = "wnids.txt"
# Annotation file read from inside the val split directory.
VAL_ANNOTATION_FILE = "val_annotations.txt"


class TinyImageNet(Dataset):
    """
    Ref: https://github.com/leemengtaiwan/tiny-imagenet/blob/master/TinyImageNet.py
    Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`.
    Parameters
    ----------
    root: string
        Root directory including `train`, `test` and `val` subdirectories.
    split: string
        Indicating which split to return as a data set.
        Valid option: [`train`, `test`, `val`]
    transform: torchvision.transforms
        A (series) of valid transformation(s).
    in_memory: bool
        Set to True if there is enough memory (about 5G) and want to minimize disk IO overhead.
    """

    def __init__(
        self,
        root,
        split="train",
        transform=None,
        target_transform=None,
        in_memory=False,
    ):
        self.root = os.path.expanduser(root)
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.in_memory = in_memory
        self.split_dir = os.path.join(self.root, self.split)
        self.image_paths = sorted(
            glob.iglob(
                os.path.join(self.split_dir, "**", "*.%s" % EXTENSION), recursive=True
            )
        )
        self.labels = {}  # fname - label number mapping
        self.images = []  # used for in-memory processing
        # build class label - number mapping
        with open(os.path.join(self.root, CLASS_LIST_FILE), "r") as fp:
            self.label_texts = sorted([text.strip() for text in fp.readlines()])
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}

        if self.split == "train":
            for label_text, i in self.label_text_to_number.items():
                for cnt in range(NUM_IMAGES_PER_CLASS):
                    self.labels["%s_%d.%s" % (label_text, cnt, EXTENSION)] = i
        elif self.split == "val":
            with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), "r") as fp:
                for line in fp.readlines():
                    terms = line.split("\t")
                    file_name, label_text = terms[0], terms[1]
                    self.labels[file_name] = self.label_text_to_number[label_text]

        # get targets
        self.targets = []
        for index in range(len(self.image_paths)):
            file_path = self.image_paths[index]
            label_numeral = self.labels[os.path.basename(file_path)]
            self.targets.append(label_numeral)

        # read all images into torch tensor in memory to minimize disk IO overhead
        if self.in_memory:
            self.images = [self.read_image(path) for path in self.image_paths]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        file_path = self.image_paths[index]

        if self.in_memory:
            img = self.images[index]
        else:
            img = self.read_image(file_path)

        if self.split == "test":
            return img
        else:
            return img, self.labels[os.path.basename(file_path)]

    def __repr__(self):
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        tmp = self.split
        fmt_str += "    Split: {}\n".format(tmp)
        fmt_str += "    Root Location: {}\n".format(self.root)
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        return fmt_str

    def read_image(self, path):
        img = Image.open(path)
        img = img.convert("RGB")
        return self.transform(img) if self.transform else img


# ---------------------------------------------------------------------------
# get datasets
# ---------------------------------------------------------------------------


def get_datasets(datasetname, **kwargs):
    """Build the train/test datasets for ``datasetname``.

    Args:
        datasetname: one of "FashionMnist", "Cifar10", "Cifar100",
            "TinyImageNet", "Cifar10Aug", "GanEnhancedCifar10".
        **kwargs: only read by "GanEnhancedCifar10", which needs the keys
            "generator_path", "dataset" and "upsample".

    Returns:
        (trainset, testset, invTrans). ``invTrans`` undoes the input
        normalization for "Cifar10" and is ``None`` for every other dataset.

    Raises:
        ValueError: if ``datasetname`` is not recognized.
    """
    cifar10_stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    def normalized(mean, std, crop_size=None):
        """ToTensor + Normalize, optionally preceded by crop/flip augmentation."""
        pipeline = []
        if crop_size is not None:
            pipeline += [
                transforms.RandomCrop(crop_size, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        pipeline += [transforms.ToTensor(), transforms.Normalize(mean, std)]
        return transforms.Compose(pipeline)

    def with_tensor_targets(*datasets):
        """Replace each dataset's ``targets`` with a tensor, in place."""
        for dataset in datasets:
            dataset.targets = torch.tensor(dataset.targets)

    def fashion_mnist():
        shared = normalized((0.1307,), (0.3081,))
        train = torchvision.datasets.FashionMNIST(
            root="./data", train=True, download=True, transform=shared
        )
        test = torchvision.datasets.FashionMNIST(
            root="./data", train=False, download=True, transform=shared
        )
        return train, test, None

    def cifar10():
        mean, std = cifar10_stats
        inverse = transforms.Compose(
            [
                transforms.Normalize(
                    mean=[0.0, 0.0, 0.0], std=[1 / 0.2023, 1 / 0.1994, 1 / 0.2010]
                ),
                transforms.Normalize(
                    mean=[-0.4914, -0.4822, -0.4465], std=[1.0, 1.0, 1.0]
                ),
            ]
        )
        train = torchvision.datasets.CIFAR10(
            root="./data",
            train=True,
            download=True,
            transform=normalized(mean, std, crop_size=32),
        )
        test = torchvision.datasets.CIFAR10(
            root="./data", train=False, download=True, transform=normalized(mean, std)
        )
        with_tensor_targets(train, test)
        return train, test, inverse

    def cifar100():
        mean, std = [0.507, 0.487, 0.441], [0.267, 0.256, 0.276]
        train = torchvision.datasets.CIFAR100(
            root="./data",
            train=True,
            download=True,
            transform=normalized(mean, std, crop_size=32),
        )
        test = torchvision.datasets.CIFAR100(
            root="./data", train=False, download=True, transform=normalized(mean, std)
        )
        with_tensor_targets(train, test)
        return train, test, None

    def tiny_imagenet():
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        data_root = "./data/tiny-imagenet-200"
        train = TinyImageNet(
            data_root,
            "train",
            transform=normalized(mean, std, crop_size=64),
            in_memory=False,
        )
        test = TinyImageNet(
            data_root, "val", transform=normalized(mean, std), in_memory=False
        )
        with_tensor_targets(train, test)
        return train, test, None

    def cifar10_aug():
        # On Bridging Generic and Personalized Federated Learning for Image
        # Classification impl
        mean, std = [0.491, 0.482, 0.447], [0.247, 0.243, 0.262]
        train = torchvision.datasets.CIFAR10(
            root="~/data",
            train=True,
            download=True,
            transform=normalized(mean, std, crop_size=32),
        )
        test = torchvision.datasets.CIFAR10(
            root="~/data", train=False, download=True, transform=normalized(mean, std)
        )
        with_tensor_targets(train, test)
        return train, test, None

    def gan_enhanced_cifar10():
        train = GanEnhancedCifar10(
            kwargs["generator_path"], kwargs["dataset"], kwargs["upsample"]
        )
        test = torchvision.datasets.CIFAR10(
            root="~/data",
            train=False,
            download=True,
            transform=normalized(*cifar10_stats),
        )
        with_tensor_targets(test)
        return train, test, None

    builders = {
        "FashionMnist": fashion_mnist,
        "Cifar10": cifar10,
        "Cifar100": cifar100,
        "TinyImageNet": tiny_imagenet,
        "Cifar10Aug": cifar10_aug,
        "GanEnhancedCifar10": gan_enhanced_cifar10,
    }
    build = builders.get(datasetname)
    if build is None:
        raise ValueError(f"Unrecognized dataset:{datasetname}")
    return build()

客户端数据划分、优化器构建、断点续训以及 state_dict 运算等工具函数

In [None]:
def setup_optimizer(model, config, round):
    """Create the client optimizer for communication round ``round``.

    Learning-rate schedule, chosen by ``config["client_lr_scheduler"]``:
      * "stepwise": ``client_lr`` while ``round < num_rounds // 2``,
        then ``client_lr * 0.1``;
      * "diminishing": ``client_lr * lr_decay_per_round ** (round - 1)``.

    Only parameters with ``requires_grad=True`` are handed to the optimizer.

    Args:
        model: ``torch.nn.Module`` whose trainable parameters are optimized.
        config: mapping with the scheduler keys above plus "optimizer" and that
            optimizer's hyper-parameters ("sgd_momentum"/"sgd_weight_decay" or
            "rmsprop_alpha"/"rmsprop_weight_decay"/"rmsprop_momentum").
        round: current communication round.

    Returns:
        A ``torch.optim.SGD``, ``Adam`` or ``RMSprop`` instance.

    Raises:
        ValueError: unknown scheduler or optimizer name.
    """
    scheduler = config["client_lr_scheduler"]
    if scheduler == "stepwise":
        if round < config["num_rounds"] // 2:
            lr = config["client_lr"]
        else:
            lr = config["client_lr"] * 0.1
    elif scheduler == "diminishing":
        lr = config["client_lr"] * (config["lr_decay_per_round"] ** (round - 1))
    else:
        raise ValueError(f"unknown client_lr_scheduler: {scheduler}")

    # Frozen parameters (e.g. a fixed prototype head) are excluded.
    trainable = [p for p in model.parameters() if p.requires_grad]
    name = config["optimizer"]
    if name == "SGD":
        return torch.optim.SGD(
            trainable,
            lr=lr,
            momentum=config["sgd_momentum"],
            weight_decay=config["sgd_weight_decay"],
        )
    if name == "Adam":
        # NOTE: Adam's weight decay is fixed here rather than read from config.
        return torch.optim.Adam(trainable, lr=lr, weight_decay=1e-5)
    if name == "RMSprop":
        return torch.optim.RMSprop(
            trainable,
            lr=lr,
            alpha=config["rmsprop_alpha"],
            eps=1e-08,
            weight_decay=config["rmsprop_weight_decay"],
            momentum=config["rmsprop_momentum"],
        )
    raise ValueError(f"Unknown optimizer: {name}")


# ---------------------------------------------------------------------------
# client initialization
# ---------------------------------------------------------------------------


def setup_clients(
    Client, trainset, testset, criterion, client_config_lst, client_pyus, **kwargs
):
    """Partition the data and build one ``Client`` per participant.

    ``Client`` is a class constructor, and each client builds its own model. The
    seed is reset to 2022 before every construction so that all clients start
    from the same initial weights.

    Args:
        Client: client class, called with keyword arguments
            ``device=client_pyus[cid]``, ``criterion``, ``trainset``, ``testset``,
            ``client_config``, ``cid`` and ``**kwargs``.
        trainset: full training dataset to split across clients.
        testset: full test dataset, or ``None``.
        criterion: loss passed to every client.
        client_config_lst: one config per client; its length must equal
            ``num_clients``.
        client_pyus: indexable by cid; the device each client is placed on.
        **kwargs: must contain "server_config" (with "num_clients", "partition",
            "num_classes") and, when ``testset`` is given, "same_testset".
            All of ``kwargs`` is forwarded to ``sampler``, ``sampler_reuse`` and
            ``Client``.

    Returns:
        (all_clients_dict, n_samples): cid -> client, and cid -> number of local
        training samples.
    """
    server_config = kwargs["server_config"]
    num_clients = server_config["num_clients"]
    partition = server_config["partition"]
    num_classes = server_config["num_classes"]
    assert (
        len(client_config_lst) == num_clients
    ), "Inconsistent num_clients and len(client_config_lst)."

    # "noniid*" partitions split by label, so the samplers also get the label vector.
    by_label = partition.startswith("noniid")

    train_labels = {"ylabels": trainset.targets} if by_label else {}
    trainset_per_client_dict, stats_dict = sampler(
        trainset,
        num_clients,
        partition,
        **train_labels,
        num_classes=num_classes,
        **kwargs,
    )

    if testset is None:
        testset_per_client_dict = dict.fromkeys(range(num_clients))
    elif kwargs["same_testset"]:
        testset_per_client_dict = dict.fromkeys(range(num_clients), testset)
    else:
        # Reuses the training-split statistics (presumably so each client's
        # test data mirrors its training distribution).
        test_labels = {"ylabels": testset.targets} if by_label else {}
        testset_per_client_dict = sampler_reuse(
            testset,
            stats_dict,
            **test_labels,
            num_classes=num_classes,
            **kwargs,
        )

    n_samples = {cid: len(split) for cid, split in trainset_per_client_dict.items()}

    all_clients_dict = {}
    for cid in range(num_clients):
        setup_seed(2022)  # same initial weights for every client
        all_clients_dict[cid] = Client(
            device=client_pyus[cid],
            criterion=criterion,
            trainset=trainset_per_client_dict[cid],
            testset=testset_per_client_dict[cid],
            client_config=client_config_lst[cid],
            cid=cid,
            **kwargs,
        )

    return all_clients_dict, n_samples


def create_clients_from_existing_ones(
    Client, clients_dict, newtrainset, increment, criterion, **kwargs
):
    """Rebuild every client on ``newtrainset``, keeping its original data distribution.

    Each client keeps its original sample indices and gains extra ones:
      * ``same_pool=True``: for every class it owns, up to ``count * scale``
        indices drawn without replacement from
        ``newtrainset.get_fake_imgs_idxs(cls)``;
      * otherwise: for each original index ``i``, the indices
        ``i + increment * j`` for ``j = 1 .. scale``.

    Keyword args (all of ``kwargs`` is also forwarded to ``Client``):
        same_pool: defaults to False.
        scale: defaults to ``len(newtrainset) // increment - 1``.

    Returns:
        dict cid -> newly constructed client.
    """
    same_pool = kwargs.get("same_pool", False)
    # Conditional (not kwargs.get) so the default is only computed when needed.
    scale = kwargs["scale"] if "scale" in kwargs else len(newtrainset) // increment - 1

    all_clients_dict = {}
    for cid in range(len(clients_dict)):
        client = clients_dict[cid]
        data_idxs = client.trainset.idxs
        if same_pool:
            add_idxs = []
            for cls, num_sample_cls in client.count_by_class.items():
                pool = newtrainset.get_fake_imgs_idxs(cls)
                take = min(num_sample_cls * scale, len(pool))
                add_idxs.extend(np.random.choice(pool, take, replace=False).tolist())
        else:
            add_idxs = [
                i + increment * j for i in data_idxs for j in range(1, scale + 1)
            ]

        # NOTE(review): group/device are passed positionally here, while
        # FedPCGClient.__init__ accepts neither — confirm which Client class this
        # helper is meant for.
        all_clients_dict[cid] = Client(
            criterion,
            DatasetSplit(newtrainset, data_idxs + add_idxs),
            client.testset,
            client.client_config,
            client.cid,
            client.group,
            client.device,
            **kwargs,
        )
    return all_clients_dict


# ---------------------------------------------------------------------------
# resume training
# ---------------------------------------------------------------------------


def resume_training(server_config, checkpoint, model):
    """Restore a server from ``checkpoint`` and re-attach fresh client models.

    The restored server's config is replaced by ``server_config``. Each client
    then gets a deep copy of ``model``, loaded with the server's
    ``server_model_state_dict``, moved to the client's device and ``init()``-ed.

    SECURITY NOTE(review): ``load_from_pkl`` presumably unpickles the file, so
    only load trusted checkpoints.
    """
    restored = load_from_pkl(checkpoint)
    restored.server_config = server_config
    for client in restored.clients_dict.values():
        client.model = deepcopy(model)
        client.set_params(restored.server_model_state_dict)
        client.model.to(client.device)
        client.model.init()
    print("Resume Training")
    print(f"Rounds performed:{restored.rounds}")
    return restored


# ---------------------------------------------------------------------------
# state_dict operations
# ---------------------------------------------------------------------------


def scale_state_dict(this, scale, inplace=True, exclude=None):
    """Multiply every non-excluded entry of a state_dict by ``scale``.

    Args:
        this: mapping from parameter name to tensor.
        scale: multiplier applied to each non-excluded entry.
        inplace: if True, ``this`` itself is updated and returned. Its entries are
            rebound to new tensors; the old tensor objects are not modified.
            If False, a deep copy is updated and returned and ``this`` is untouched.
        exclude: keys to leave unscaled (default: none).

    Returns:
        The scaled mapping.
    """
    # None instead of a shared mutable default set.
    exclude = frozenset() if exclude is None else exclude
    with torch.no_grad():
        ans = this if inplace else deepcopy(this)
        for state_key in this.keys():
            if state_key not in exclude:
                ans[state_key] = this[state_key] * scale
        return ans


def linear_combination_state_dict(
    this, other, this_weight=1.0, other_weight=1.0, exclude=None
):
    """Return ``this_weight * this + other_weight * other``, key by key.

    Args:
        this, other: state_dicts. ``other`` must contain every non-excluded key
            of ``this``; extra keys in ``other`` are ignored.
        this_weight, other_weight: scalar coefficients.
        exclude: keys for which the result keeps a deep copy of ``this``'s value
            (default: none).

    Returns:
        A new state_dict; neither input is modified.
    """
    # None instead of a shared mutable default set.
    exclude = frozenset() if exclude is None else exclude
    with torch.no_grad():
        ans = deepcopy(this)
        for state_key in this.keys():
            if state_key not in exclude:
                ans[state_key] = (
                    this[state_key] * this_weight + other[state_key] * other_weight
                )
        return ans


def average_list_of_state_dict(state_dict_lst, exclude=None):
    """Element-wise mean of a list of state_dicts.

    Keys are taken from the first state_dict. Keys listed in ``exclude`` are
    omitted from the result. The input state_dicts are not modified.

    Args:
        state_dict_lst: non-empty list of state_dicts with matching keys.
        exclude: keys to skip (default: none).

    Returns:
        OrderedDict key -> mean tensor.

    Raises:
        AssertionError: if ``state_dict_lst`` is not a list.
        IndexError: if ``state_dict_lst`` is empty.
    """
    assert isinstance(state_dict_lst, list)
    # None instead of a shared mutable default set.
    exclude = frozenset() if exclude is None else exclude
    num_participants = len(state_dict_lst)
    first = state_dict_lst[0]
    with torch.no_grad():
        ans = OrderedDict()
        for key in first.keys():
            # Test the loop variable; the previous `state_key` was undefined
            # (NameError on every call).
            if key in exclude:
                continue
            # deepcopy so the in-place += below never writes into the first input.
            total = deepcopy(first[key])
            for client_state_dict in state_dict_lst[1:]:
                total += client_state_dict[key]
            ans[key] = total / num_participants
    return ans


def weight_sum_of_dict_of_state_dict(dict_state_dict, weight_dict):
    """Weighted sum ``sum_cid weight_dict[cid] * dict_state_dict[cid]``, per layer.

    Args:
        dict_state_dict: cid -> state_dict. Layer keys are taken from the first entry.
        weight_dict: cid -> scalar weight (every cid of ``dict_state_dict`` must be present).

    Returns:
        OrderedDict layer -> weighted-sum tensor. The inputs are not modified.
    """
    template = next(iter(dict_state_dict.values()))
    with torch.no_grad():
        ans = OrderedDict()
        for layer in template.keys():
            acc = None
            for cid, state_dict in dict_state_dict.items():
                # The product is a fresh tensor, so accumulating into it in place
                # never touches the inputs.
                weighted = state_dict[layer] * weight_dict[cid]
                if acc is None:
                    acc = weighted
                else:
                    acc += weighted
            ans[layer] = acc
    return ans

下面定义联邦学习场景下的客户端类 FedPCGClient 和服务器类 FedPCGServer，实现一种个性化联邦学习方法：客户端上传融合本地模型与全局模型特征的类别原型，服务器在更新全局 head 时同时利用这些本地原型与全局原型。

In [None]:
@proxy(PYUObject)
class FedPCGClient(Client):
    """FedPCG client, run as a SecretFlow PYU object via ``@proxy(PYUObject)``.

    Local objective: cross-entropy plus ``beta`` times a not-true-distillation
    (NTD) loss against the frozen global model received in ``training``.

    On ``upload`` the client returns its new ``state_dict`` and per-class
    prototypes. They come from one of two estimators, selected by
    ``client_config["FedNH_client_adv_prototype_agg"]``:
      * ``_estimate_prototype``: a 0.7 local / 0.3 global-model embedding blend;
      * ``_estimate_prototype_adv``: a confidence-weighted mean of local embeddings.
    """

    def __init__(self, criterion, trainset, testset, client_config, cid, **kwargs):
        super().__init__(criterion, trainset, testset, client_config, cid, **kwargs)
        # Pin device and config BEFORE building the model. `_initialize_model`
        # places the model and criterion on `self.device`, and `training`/`testing`
        # move every batch to `self.device`, so the two must agree.
        self.device = "cpu"
        self.client_config = client_config
        self._initialize_model()
        # Frozen teacher that receives the server weights in `training` (NTD).
        self.global_model = deepcopy(self.model)
        # NTD hyper-parameters: loss weight and softmax temperature.
        self.beta = 1
        self.tau = 0.5
        # Must equal the model's number of classes, because `refine_as_not_true`
        # reshapes the logits to (batch, num_classes - 1).
        self.num_classes = client_config["num_classes"]
        # The `criterion` argument is replaced by plain cross-entropy here.
        self.criterion = nn.CrossEntropyLoss()
        self.KLDiv = nn.KLDivLoss(reduction="batchmean")
        # Counts over all classes; locally absent classes get 1e-12 instead of 0.
        temp = [
            self.count_by_class[cls] if cls in self.count_by_class.keys() else 1e-12
            for cls in range(client_config["num_classes"])
        ]
        self.count_by_class_full = torch.tensor(temp).to(self.device)

        # Second model copy, used to embed local data with the global weights
        # during prototype estimation.
        self.global_model2 = deepcopy(self.model)

    def get_model_named_parameters(self):
        """Return ``(name, parameter)`` pairs of the local model as a list."""
        return list(self.model.named_parameters())

    def _estimate_prototype(self, global_model2):
        """Estimate count-scaled class prototypes from local and global embeddings.

        For each local class ``c``:
          * embedding sums are blended as ``0.7 * local + 0.3 * global``;
          * the blend is divided by ``count_by_class[c]`` and L2-normalized;
          * the result is multiplied by ``count_by_class[c]``.
        The global embeddings come from ``global_model2``, which is loaded into
        ``self.global_model2``. Classes absent locally stay zero.

        Args:
            global_model2: mapping of tensors keyed like the model's state_dict.

        Returns:
            dict with "scaled_prototype" (same shape as ``self.model.prototype``)
            and "count_by_class_full".
        """
        self.model.eval()
        self.model.return_embedding = True
        prototype = torch.zeros_like(self.model.prototype)
        self.set_gloabl_param(self.global_model2, global_model2)
        self.global_model2.eval()
        self.global_model2.return_embedding = True
        with torch.no_grad():
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                feature_embedding, _ = self.model.forward(x)
                feature_embedding_global, _ = self.global_model2.forward(x)
                classes_shown_in_this_batch = torch.unique(y).cpu().numpy()
                for cls in classes_shown_in_this_batch:
                    mask = y == cls
                    feature_embedding_in_cls = torch.sum(
                        feature_embedding[mask, :], dim=0
                    )
                    feature_embedding_global_in_cls = torch.sum(
                        feature_embedding_global[mask, :], dim=0
                    )
                    # Fixed 0.7 / 0.3 blend of local and global-model embeddings.
                    prototype[cls] += (
                        0.7 * feature_embedding_in_cls
                        + 0.3 * feature_embedding_global_in_cls
                    )
        for cls in self.count_by_class.keys():
            prototype[cls] /= self.count_by_class[cls]
            prototype_cls_norm = torch.norm(prototype[cls]).clamp(min=1e-12)
            prototype[cls] = torch.div(prototype[cls], prototype_cls_norm)
            # Presumably lets the server form count-weighted averages.
            prototype[cls] *= self.count_by_class[cls]

        self.model.return_embedding = False

        to_share = {
            "scaled_prototype": prototype,
            "count_by_class_full": self.count_by_class_full,
        }
        return to_share

    def _estimate_prototype_adv(self):
        """Estimate unit-norm prototypes as confidence-weighted mean embeddings.

        Each sample is weighted by the softmax probability the local model assigns
        to its true label. Classes absent locally stay zero.

        Returns:
            dict with "adv_agg_prototype" and "count_by_class_full".
        """
        self.model.eval()
        self.model.return_embedding = True
        embeddings = []
        labels = []
        weights = []
        prototype = torch.zeros_like(self.model.prototype)
        with torch.no_grad():
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                feature_embedding, logits = self.model.forward(x)
                prob_ = F.softmax(logits, dim=1)
                # Probability of the true class: the per-sample weight.
                prob = torch.gather(prob_, dim=1, index=y.view(-1, 1))
                labels.append(y)
                weights.append(prob)
                embeddings.append(feature_embedding)
        self.model.return_embedding = False
        embeddings = torch.cat(embeddings, dim=0)
        labels = torch.cat(labels, dim=0)
        weights = torch.cat(weights, dim=0).view(-1, 1)
        for cls in self.count_by_class.keys():
            mask = labels == cls
            weights_in_cls = weights[mask, :]
            feature_embedding_in_cls = embeddings[mask, :]
            prototype[cls] = torch.sum(
                feature_embedding_in_cls * weights_in_cls, dim=0
            ) / torch.sum(weights_in_cls)
            prototype_cls_norm = torch.norm(prototype[cls]).clamp(min=1e-12)
            prototype[cls] = torch.div(prototype[cls], prototype_cls_norm)

        to_share = {
            "adv_agg_prototype": prototype,
            "count_by_class_full": self.count_by_class_full,
        }
        return to_share

    @staticmethod
    def _get_orthonormal_basis(m, n):
        """Return an ``(m, n)`` matrix with orthonormal rows (requires ``m <= n``).

        Classical Gram-Schmidt is applied to a random uniform matrix.

        Raises:
            ValueError: if a row becomes exactly zero (linear dependence).
        """
        W = torch.rand(m, n)
        for i in range(m):
            q = W[i, :]
            for j in range(i):
                q = q - torch.dot(W[j, :], W[i, :]) * W[j, :]
            if torch.equal(q, torch.zeros_like(q)):
                raise ValueError("The row vectors are not linearly independent!")
            q = q / torch.sqrt(torch.dot(q, q))
            W[i, :] = q
        return W

    def setup_seed_local(self, seed):
        """Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def _initialize_model(self):
        """Build the model on ``self.device`` and initialize its fixed prototype head.

        ``client_config["FedNH_head_init"]`` selects the prototype init:
          * "orthogonal": any shape;
          * "uniform": points evenly spaced on the unit circle, only when
            ``client_config["dim"] == 2``.
        If ``client_config["FedNH_fix_scaling"] == True``, the model's ``scaling``
        is frozen at 30.0.

        Raises:
            NotImplementedError: unsupported init option, or the model has no
                ``prototype`` attribute.
        """
        import math  # used by the 2-D "uniform" initializer; not imported at file level

        self.model = Conv2CifarNH(self.client_config).to(self.device)
        self.criterion = self.criterion.to(self.device)
        try:
            self.model.prototype.requires_grad_(False)
            if self.client_config["FedNH_head_init"] == "orthogonal":
                m, n = self.model.prototype.shape
                self.model.prototype.data = torch.nn.init.orthogonal_(
                    torch.rand(m, n)
                ).to(self.device)
            elif (
                self.client_config["FedNH_head_init"] == "uniform"
                and self.client_config["dim"] == 2
            ):
                r = 1.0
                num_cls = self.client_config["num_classes"]
                W = torch.zeros(num_cls, 2)
                for i in range(num_cls):
                    theta = i * 2 * torch.pi / num_cls
                    W[i, :] = torch.tensor([r * math.cos(theta), r * math.sin(theta)])
                self.model.prototype.copy_(W)
            else:
                raise NotImplementedError(
                    f"{self.client_config['FedNH_head_init']} + {self.client_config['num_classes']}d"
                )
        except AttributeError as err:
            raise NotImplementedError("Only support linear layers now.") from err
        if self.client_config["FedNH_fix_scaling"] == True:
            # 30.0 is a common choice in the paper
            self.model.scaling.requires_grad_(False)
            self.model.scaling.data = torch.tensor(30.0).to(self.device)
            print("self.model.scaling.data:", self.model.scaling.data)

    def set_gloabl_param(self, g1, g2):
        """Copy every tensor of mapping ``g2`` into ``g1``'s state_dict, in place."""
        with torch.no_grad():
            for key in g2.keys():
                g1.state_dict()[key].copy_(g2[key])

    def training(self, round, num_epochs, global_model):
        """Run ``num_epochs`` local epochs with CE + ``beta`` * NTD loss.

        Note: to train from the latest server-side model, call ``set_params``
        before this method.

        Args:
            round: communication round. It seeds the local RNGs and drives the
                learning-rate schedule in ``setup_optimizer``.
            num_epochs: number of passes over ``self.trainloader``.
            global_model: mapping of tensors keyed like the model's state_dict,
                loaded into the frozen teacher ``self.global_model``.

        Side effects:
            Sets ``self.new_state_dict``. Records per-epoch loss and accuracy in
            ``train_loss_dict[round]`` and ``train_acc_dict[round]``.

        Raises:
            ValueError: if no train loader is available.
        """
        print("Begin local training!")
        train_start = time.time()
        self.setup_seed_local(round)
        self.model.train()
        self.set_gloabl_param(self.global_model, global_model)
        self.global_model = self.global_model.eval().requires_grad_(False)
        self.num_rounds_particiapted += 1
        loss_seq = []
        acc_seq = []
        if self.trainloader is None:
            raise ValueError("No trainloader is provided!")
        optimizer = setup_optimizer(self.model, self.client_config, round)
        for _ in range(num_epochs):
            epoch_loss, correct = 0.0, 0
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                yhat = self.model.forward(x)
                loss = self.criterion(yhat, y)
                y_g = self.global_model.forward(x)
                loss += self._ntd_loss(yhat, y_g, y) * self.beta
                self.model.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    parameters=filter(
                        lambda p: p.requires_grad, self.model.parameters()
                    ),
                    max_norm=10,
                )
                optimizer.step()
                predicted = yhat.data.max(1)[1]
                correct += predicted.eq(y.data).sum().item()
                epoch_loss += loss.item() * x.shape[0]  # rescale by batch size

            epoch_loss /= len(self.trainloader.dataset)
            epoch_accuracy = correct / len(self.trainloader.dataset)
            loss_seq.append(epoch_loss)
            acc_seq.append(epoch_accuracy)
        self.new_state_dict = self.model.state_dict()
        self.train_loss_dict[round] = loss_seq
        self.train_acc_dict[round] = acc_seq
        print("Local training completed!")
        train_time = time.time() - train_start
        print(f"Local training time:{train_time:.3f} seconds")

    def get_train_loss_dict(self, r):
        return self.train_loss_dict[r]

    def get_train_acc_dict(self, r):
        return self.train_acc_dict[r]

    def get_test_loss_dict(self, r):
        return self.test_loss_dict[r]

    def get_test_acc_dict(self, r):
        return self.test_acc_dict[r]

    def get_num_train_samples(self):
        return self.num_train_samples

    def get_testloader(self):
        return self.testloader

    def _ntd_loss(self, logits, dg_logits, targets):
        """Not-true distillation loss.

        This is the KL divergence between the temperature-softened local and
        global predictions, restricted to the non-target classes and scaled by
        ``tau**2``. The global (teacher) side carries no gradient.
        """
        logits = refine_as_not_true(logits, targets, self.num_classes)
        pred_probs = F.log_softmax(logits / self.tau, dim=1)

        with torch.no_grad():
            dg_logits = refine_as_not_true(dg_logits, targets, self.num_classes)
            dg_probs = torch.softmax(dg_logits / self.tau, dim=1)

        loss = (self.tau**2) * self.KLDiv(pred_probs, dg_probs)

        return loss

    def upload(self, global_model2):
        """Return ``(new_state_dict, prototype_dict)`` for the server."""
        if self.client_config["FedNH_client_adv_prototype_agg"]:
            return self.new_state_dict, self._estimate_prototype_adv()
        else:
            return self.new_state_dict, self._estimate_prototype(global_model2)

    def testing(self, round, testloader=None):
        """Evaluate the local model and store accuracies in ``test_acc_dict[round]``.

        Three accuracies are computed as ``sum(w * correct) / sum(w * count)``
        over classes:
          * "uniform": every class weighted 1;
          * "validclass": only classes in ``label_dist`` (weight 1);
          * "labeldist": weights taken from ``label_dist``.

        Args:
            round: key under which the results are stored.
            testloader: loader to evaluate; defaults to ``self.testloader``.
                Its dataset must expose a ``targets`` tensor.
        """
        self.model.eval()
        if testloader is None:
            testloader = self.testloader
        test_count_per_class = Counter(testloader.dataset.targets.numpy())
        num_classes = self.client_config["num_classes"]
        test_count_per_class = torch.tensor(
            [test_count_per_class[cls] * 1.0 for cls in range(num_classes)]
        )
        test_correct_per_class = torch.tensor([0] * num_classes)

        weight_per_class_dict = {
            "uniform": torch.tensor([1.0] * num_classes),
            "validclass": torch.tensor([0.0] * num_classes),
            "labeldist": torch.tensor([0.0] * num_classes),
        }
        for cls in self.label_dist.keys():
            weight_per_class_dict["labeldist"][cls] = self.label_dist[cls]
            weight_per_class_dict["validclass"][cls] = 1.0
        with torch.no_grad():
            for x, y in testloader:
                x, y = x.to(self.device), y.to(self.device)
                yhat = self.model.forward(x)
                predicted = yhat.data.max(1)[1]
                classes_shown_in_this_batch = torch.unique(y).cpu().numpy()
                for cls in classes_shown_in_this_batch:
                    test_correct_per_class[cls] += (
                        ((predicted == y) * (y == cls)).sum().item()
                    )
        acc_by_critertia_dict = {}
        for k in weight_per_class_dict.keys():
            acc_by_critertia_dict[k] = (
                ((weight_per_class_dict[k] * test_correct_per_class).sum())
                / ((weight_per_class_dict[k] * test_count_per_class).sum())
            ).item()

        self.test_acc_dict[round] = {
            "acc_by_criteria": acc_by_critertia_dict,
            "correct_per_class": test_correct_per_class,
            "weight_per_class": weight_per_class_dict,
        }


def refine_as_not_true(logits, targets, num_classes):
    """Drop each row's target-class logit, keeping the other classes in order.

    Args:
        logits: 2-D tensor ``(batch, C)``; assumes ``C >= num_classes``.
        targets: ``(batch,)`` integer class indices.
        num_classes: number of classes.

    Returns:
        ``(batch, num_classes - 1)`` tensor gathered from ``logits``, so gradients
        flow back to the kept entries.
    """
    batch_size = logits.size(0)
    positions = torch.arange(0, num_classes).to(logits.device).repeat(batch_size, 1)
    not_true = positions[positions != targets.view(-1, 1)].view(-1, num_classes - 1)
    return torch.gather(logits, 1, not_true)


@proxy(PYUObject)
class FedPCGServer(Server):
    """FedPCG server: FedNH-style aggregation combined with clustered sampling.

    Each round the server broadcasts the global model, lets the active clients
    train and upload ``(state_dict, prototype_dict)`` pairs, aggregates the
    model updates and class prototypes, and feeds every client's local model
    to ``ClusteredSampling2`` so that later rounds sample clients through
    ``client_selection`` instead of the base-class ``select_clients``.

    The class is wrapped with SecretFlow's ``proxy(PYUObject)`` and is built on
    the PYU passed as ``device=``; objects produced by other parties are
    materialised with ``sf.reveal`` before they are used locally.
    """

    def __init__(self, n_samples, server_config, clients_dict, exclude, **kwargs):
        """Initialise the global model, the exclusion set and the sampler.

        Args:
            n_samples: forwarded to ``ClusteredSampling2.setup``; presumably
                the per-client sample counts returned by ``setup_clients``
                (TODO confirm).
            server_config: dict of server hyper-parameters.
            clients_dict: mapping ``client id -> client`` (proxied objects).
            exclude: substrings; every global state-dict key containing one
                of them is left out of aggregation.
            **kwargs: forwarded unchanged to ``Server.__init__``.
        """
        super().__init__(server_config, clients_dict, **kwargs)

        self.device = "cpu"
        self.summary_setup()
        # The global model starts from client 0's parameters.
        self.server_model_state_dict = deepcopy(self.clients_dict[0].get_params())
        self.server_side_client.set_params(
            self.server_model_state_dict.to(self.server_side_client.device),
            exclude_keys=set(),
        )
        self.exclude_layer_keys = set()
        for key in sf.reveal(self.server_model_state_dict):
            if any(ekey in key for ekey in exclude):
                self.exclude_layer_keys.add(key)
        if len(self.exclude_layer_keys) > 0:
            print(
                f"{self.server_config['strategy']}Server: the following keys will not be aggregated:\n ",
                self.exclude_layer_keys,
            )
        # Report parameters that are frozen on the server-side model.
        freeze_layers = [
            name
            for name, param in sf.reveal(
                self.server_side_client.get_model_named_parameters()
            )
            if not param.requires_grad
        ]
        if len(freeze_layers) > 0:
            print("Server: the following layers will not be updated:", freeze_layers)
        self.selection = ClusteredSampling2(server_config["num_clients"], "cpu", "L1")
        self.nsamples = n_samples
        self.selection.setup(self.nsamples)

    def aggregate(self, client_uploads, round):
        """Fold the clients' uploads into the global model and prototypes.

        The model is updated as
        ``global + server_lr / num_participants * sum(client - global)``,
        with ``server_lr`` decayed by ``lr_decay_per_round`` per round.
        Prototypes are averaged (count-normalised, or with exp-similarity
        weights when ``FedNH_server_adv_prototype_agg`` is set), normalised,
        smoothed with ``FedNH_smoothing`` and normalised again.

        Args:
            client_uploads: list of (possibly proxied) ``(state_dict,
                prototype_dict)`` pairs, one per participating client.
            round: 1-based round index used for learning-rate decay.
        """
        server_lr = self.server_config["learning_rate"] * (
            self.server_config["lr_decay_per_round"] ** (round - 1)
        )
        adv_agg = self.server_config["FedNH_server_adv_prototype_agg"]
        # Reveal once so every pass below sees concrete (state_dict,
        # prototype_dict) pairs; iterating the unrevealed list cannot be
        # unpacked into pairs.
        uploads = sf.reveal(client_uploads)
        num_participants = len(uploads)
        update_direction_state_dict = None
        # Use this server's config, not the notebook-global ``server_config``.
        cumsum_per_class = torch.zeros(self.server_config["num_classes"])
        agg_weights_vec_dict = {}
        with torch.no_grad():
            current_state = sf.reveal(self.server_model_state_dict)
            for idx, (client_state_dict, prototype_dict) in enumerate(uploads):
                if adv_agg == False:
                    cumsum_per_class += prototype_dict["count_by_class_full"]
                else:
                    mu = prototype_dict["adv_agg_prototype"]
                    W = current_state["prototype"]
                    agg_weights_vec_dict[idx] = torch.exp(
                        torch.sum(W * mu, dim=1, keepdim=True)
                    )
                # Per-client update direction: client - global.
                client_update = linear_combination_state_dict(
                    client_state_dict,
                    current_state,
                    1.0,
                    -1.0,
                    exclude=self.exclude_layer_keys,
                )
                if idx == 0:
                    update_direction_state_dict = client_update
                else:
                    update_direction_state_dict = linear_combination_state_dict(
                        update_direction_state_dict,
                        client_update,
                        1.0,
                        1.0,
                        exclude=self.exclude_layer_keys,
                    )
            # new feature extractor
            self.server_model_state_dict = linear_combination_state_dict(
                current_state,
                update_direction_state_dict,
                1.0,
                server_lr / num_participants,
                exclude=self.exclude_layer_keys,
            )

            avg_prototype = torch.zeros_like(self.server_model_state_dict["prototype"])
            if adv_agg == False:
                # NOTE(review): a class with zero total count gives 0/0 here;
                # confirm clients never report that.
                class_counts = cumsum_per_class.view(-1, 1)
                for _, prototype_dict in uploads:
                    avg_prototype += prototype_dict["scaled_prototype"] / class_counts
            else:
                m = self.server_model_state_dict["prototype"].shape[0]
                sum_of_weights = torch.zeros((m, 1)).to(avg_prototype.device)
                for idx, (_, prototype_dict) in enumerate(uploads):
                    sum_of_weights += agg_weights_vec_dict[idx]
                    avg_prototype += (
                        agg_weights_vec_dict[idx] * prototype_dict["adv_agg_prototype"]
                    )
                avg_prototype /= sum_of_weights

            avg_prototype = F.normalize(avg_prototype, dim=1)
            # Moving average between the old and the freshly aggregated prototypes.
            weight = self.server_config["FedNH_smoothing"]
            temp = (
                weight * self.server_model_state_dict["prototype"]
                + (1 - weight) * avg_prototype
            )
            self.server_model_state_dict["prototype"].copy_(F.normalize(temp, dim=1))

    def testing(self, round, active_only=True, **kwargs):
        """Evaluate the global model and the clients' models.

        The global model is loaded into the server-side client and tested on
        the global test set; then each selected client runs ``testing`` on
        its own split (``split_testset`` True) or on the server-side client's
        test loader.

        Args:
            round: round index forwarded to every ``testing`` call.
            active_only: only compute statistics for the active clients.
        """
        self.server_side_client.set_params(
            self.server_model_state_dict, self.exclude_layer_keys
        )
        self.server_side_client.testing(round, testloader=None)  # use global testdataset
        server_stats = sf.reveal(self.server_side_client.get_test_acc_dict(round))
        print(
            " server global model correct",
            torch.sum(server_stats["correct_per_class"]).item(),
        )
        client_indices = (
            self.active_clients_indicies if active_only else self.clients_dict.keys()
        )
        for cid in client_indices:
            client = self.clients_dict[cid]
            if self.server_config["split_testset"] == True:
                client.testing(round, None)
            else:
                client.testing(
                    round, self.server_side_client.get_testloader().to(client.device)
                )

    def setup_seed_global(self, seed):
        """Seed ``random``, NumPy and torch (CPU and all CUDA devices)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def collect_stats(self, stage, round, active_only, **kwargs):
        """
        No actual training and testing is performed. Just collect stats.
        stage: str;
            {"train", "test"}
        active_only: bool;
            True: compute stats on active clients only
            False: compute stats on all clients

        "train": sample-weighted average of each client's last recorded
        train loss/accuracy for ``round``.
        "test": stores the revealed global-model results in
        ``gfl_test_acc_dict[round]`` and the unweighted mean of the clients'
        ``acc_by_criteria`` in ``average_pfl_test_acc_dict[round]``.
        """
        client_indices = (
            self.active_clients_indicies if active_only else self.clients_dict.keys()
        )
        if stage == "train":
            total_loss = 0.0
            total_acc = 0.0
            total_samples = 0
            for cid in client_indices:
                client = self.clients_dict[cid]
                loss = sf.reveal(client.get_train_loss_dict(round))[-1]
                acc = sf.reveal(client.get_train_acc_dict(round))[-1]
                num_samples = sf.reveal(client.get_num_train_samples())
                total_loss += loss * num_samples
                total_acc += acc * num_samples
                total_samples += num_samples
            self.average_train_loss_dict[round] = total_loss / total_samples
            self.average_train_acc_dict[round] = total_acc / total_samples
        else:
            # Store the revealed result so later readers get a plain dict.
            server_acc = sf.reveal(self.server_side_client.get_test_acc_dict(round))
            self.gfl_test_acc_dict[round] = server_acc
            acc_criteria = server_acc["acc_by_criteria"].keys()
            pfl_acc = {key: 0.0 for key in acc_criteria}
            self.average_pfl_test_acc_dict[round] = pfl_acc
            for cid in client_indices:
                client_acc = sf.reveal(
                    self.clients_dict[cid].get_test_acc_dict(round)
                )["acc_by_criteria"]
                for key in acc_criteria:
                    pfl_acc[key] += client_acc[key]

            num_participants = len(client_indices)
            for key in acc_criteria:
                pfl_acc[key] /= num_participants

    def client_selection(self):
        """Select ``int(num_clients * participate_ratio)`` clients via clustered sampling."""
        num_clients = self.server_config["num_clients"]
        n = int(num_clients * self.server_config["participate_ratio"])
        return self.selection.select(n, list(range(num_clients)))

    def run(self, device, **kwargs):
        """Run federated training for rounds ``self.rounds + 1 .. num_rounds``.

        Args:
            device: device that client uploads and local models are moved to
                (the driver passes ``server.device``).
            **kwargs:
                filename (optional): path where the best global model so far
                    is saved with ``torch.save``.
                use_wandb (optional, default False): log per-round stats to
                    wandb.
                global_seed (optional, default 0): added to the round index
                    when seeding ``random``/NumPy/torch.
        """
        filename = kwargs.get("filename")
        use_wandb = kwargs.get("use_wandb", False)
        global_seed = kwargs.get("global_seed", 0)

        round_iterator = range(self.rounds + 1, self.server_config["num_rounds"] + 1)
        if self.server_config["use_tqdm"]:
            round_iterator = tqdm(round_iterator, desc="Round Progress")

        # Initialised up front so a resumed run (gfl_test_acc_dict already
        # holding >= 2 entries) cannot hit an unbound local below.
        best_test_acc = float("-inf")
        for r in round_iterator:
            # Honour the caller's global seed (it used to be accepted but ignored).
            self.setup_seed_global(r + global_seed)
            # ``self.selection.init`` is first called after local training in
            # round 1, so round 1 uses the base-class sampler.
            if r == 1:
                selected_indices = self.select_clients(
                    self.server_config["participate_ratio"]
                )
            else:
                selected_indices = self.client_selection()
            drop_ratio = self.server_config["drop_ratio"]
            if drop_ratio > 0:
                # Randomly drop a ``drop_ratio`` fraction of the selected clients.
                self.active_clients_indicies = np.random.choice(
                    selected_indices,
                    int(len(selected_indices) * (1 - drop_ratio)),
                    replace=False,
                )
            else:
                self.active_clients_indicies = selected_indices
            tqdm.write(f"Round:{r} - Active clients:{self.active_clients_indicies}:")

            # Broadcast the current global model to the active clients.
            global_state = sf.reveal(self.server_model_state_dict)
            for cid in self.active_clients_indicies:
                self.clients_dict[cid].set_params(global_state, self.exclude_layer_keys)

            client_uploads = []
            for cid in self.active_clients_indicies:
                client = self.clients_dict[cid]
                # NOTE(review): ``client_config`` is the notebook-global client
                # config rather than a per-client one; confirm this is intended.
                client.training(r, client_config["num_epochs"], global_state)
                client_upload = client.upload(global_state)
                client_uploads.append(client_upload.to(device))

            # Clustered sampling uses the local models of all clients.
            local_models = [
                self.clients_dict[cid].get_model().to(device)
                for cid in range(self.server_config["num_clients"])
            ]
            self.selection.init(self.server_model_state_dict, local_models)

            self.collect_stats(stage="train", round=r, active_only=True)

            self.aggregate(client_uploads, round=r)

            if (r - 1) % self.server_config["test_every"] == 0:
                test_start = time.time()
                self.testing(round=r, active_only=True)
                test_time = time.time() - test_start
                print(f" Testing time:{test_time:.3f} seconds")
                self.collect_stats(stage="test", round=r, active_only=True)
                gfl_acc = sf.reveal(self.gfl_test_acc_dict[r])["acc_by_criteria"]
                print(" avg_test_acc:", gfl_acc)
                print(" pfl_avg_test_acc:", self.average_pfl_test_acc_dict[r])
                if len(self.gfl_test_acc_dict) >= 2:
                    if gfl_acc["uniform"] > best_test_acc:
                        best_test_acc = gfl_acc["uniform"]
                        self.server_model_state_dict_best_so_far = deepcopy(
                            self.server_model_state_dict
                        )
                        # Only mention/save to a path when one was supplied.
                        if filename is not None:
                            tqdm.write(
                                f" Best test accuracy:{float(best_test_acc):5.3f}. Best server model is updatded and saved at {filename}!"
                            )
                            torch.save(
                                sf.reveal(self.server_model_state_dict_best_so_far),
                                filename,
                            )
                        else:
                            tqdm.write(
                                f" Best test accuracy:{float(best_test_acc):5.3f}."
                            )
                else:
                    best_test_acc = gfl_acc["uniform"]
            if use_wandb:
                stats = {
                    "avg_train_loss": self.average_train_loss_dict[r],
                    "avg_train_acc": self.average_train_acc_dict[r],
                }
                # Test statistics exist only for rounds in which testing ran.
                if r in self.gfl_test_acc_dict:
                    stats["gfl_test_acc_uniform"] = sf.reveal(
                        self.gfl_test_acc_dict[r]
                    )["acc_by_criteria"]["uniform"]
                for criteria, value in self.average_pfl_test_acc_dict.get(r, {}).items():
                    stats[f"pfl_test_acc_{criteria}"] = value

                wandb.log(stats)

隐语环境下FedPCG的实现流程，这里只初始化5个客户端为例，可以修改config.ini中的客户端数量和采样率，提高模型的表现。

In [None]:
import secretflow as sf

# Driver cell: (re)start SecretFlow with one PYU per client plus "server" and
# "fake_client", choose client/server classes for args.strategy, build the
# clients and run federated training.
# Shut down any previous session so this cell can be re-run in the notebook.
sf.shutdown()
num_clients = server_config["num_clients"]
# NOTE(review): the "fake_client" party is not referenced in this cell;
# presumably used elsewhere — verify it is still needed.
sf.init(
    [f"client_{i}" for i in range(server_config["num_clients"])]
    + ["server"]
    + ["fake_client"],
    address="local",
)

# Pick the client/server constructors and copy strategy-specific flags into
# the configs.
# NOTE(review): "Local" is not handled here, so args.strategy == "Local"
# raises ValueError and the Local branch at the end of this cell is
# unreachable.
if args.strategy == "FedAvg":
    ClientCstr, ServerCstr = FedAvgClient, FedAvgServer

elif args.strategy == "FedNH":
    ClientCstr, ServerCstr = FedNHClient, FedNHServer
    server_config["FedNH_smoothing"] = args.FedNH_smoothing
    server_config["FedNH_server_adv_prototype_agg"] = (
        args.FedNH_server_adv_prototype_agg
    )
    client_config["FedNH_client_adv_prototype_agg"] = (
        args.FedNH_client_adv_prototype_agg
    )

elif args.strategy == "FedPCG":
    ClientCstr, ServerCstr = FedPCGClient, FedPCGServer
    server_config["FedNH_smoothing"] = args.FedNH_smoothing
    server_config["FedNH_server_adv_prototype_agg"] = (
        args.FedNH_server_adv_prototype_agg
    )
    client_config["FedNH_client_adv_prototype_agg"] = (
        args.FedNH_client_adv_prototype_agg
    )
else:
    raise ValueError("Invalid strategy!")

# Output directory; result files below are named path + "_<name>.pkl".
directory = f"./{args.purpose}_{server_config['strategy']}/"
mkdirs(directory)
path = directory
print("results are saved in: ", path)
# NOTE(review): every entry is the same client_config dict object (not a
# copy), and the length uses args.num_clients while the PYUs above use
# server_config["num_clients"] — confirm these agree.
client_config_lst = [client_config for i in range(args.num_clients)]
criterion = nn.CrossEntropyLoss()

trainset, testset, _ = get_datasets(server_config["dataset"])
client_pyus = [sf.PYU(f"client_{i}") for i in range(num_clients)]
server_pyu = sf.PYU("server")
# setup clients
if server_config["split_testset"] == False:
    # No test set is handed to setup_clients here.
    clients_dict, n_samples = setup_clients(
        ClientCstr,
        trainset,
        None,
        criterion,
        client_config_lst,
        client_pyus,
        server_config=server_config,
        beta=server_config["beta"],
        num_classes_per_client=server_config["num_classes_per_client"],
        num_shards_per_client=server_config["num_shards_per_client"],
    )
else:
    # The global test set is passed to setup_clients with same_testset=False.
    print("split test set!")
    clients_dict, n_samples = setup_clients(
        ClientCstr,
        trainset,
        testset,
        criterion,
        client_config_lst,
        client_pyus,
        server_config=server_config,
        beta=server_config["beta"],
        num_classes_per_client=server_config["num_classes_per_client"],
        num_shards_per_client=server_config["num_shards_per_client"],
        same_testset=False,
    )

if args.strategy != "Local":
    print("ClientCstr", ClientCstr)
    # NOTE(review): FedPCGServer is used regardless of ServerCstr, so the
    # FedAvg/FedNH strategies still get the FedPCG server (with their own
    # client class).
    server = FedPCGServer(
        n_samples=n_samples,
        device=server_pyu,
        server_config=server_config,
        clients_dict=clients_dict,
        exclude=server_config["exclude"],
        server_side_criterion=criterion,
        global_testset=testset,
        global_trainset=trainset,
        client_cstr=ClientCstr,
        server_side_client_config=client_config,
        server_side_client_device=args.device,
    )
    # Print only the config entries whose key contains the strategy name.
    print("Strategy Related Hyper-parameters:")
    print("server side")
    for k in server_config.keys():
        if args.strategy in k:
            print(" ", k, ":", server_config[k])
    print("client side")
    for k in client_config.keys():
        if args.strategy in k:
            print(" ", k, ":", client_config[k])
    server.run(
        device=server.device,
        filename=path + "_best_global_model.pkl",
        use_wandb=use_wandb,
        global_seed=args.global_seed,
    )
else:
    # Local-only baseline: each client trains independently from the same
    # initial weights for participate_ratio * num_rounds rounds.
    # NOTE(review): reads/assigns client attributes directly; if the clients
    # are SecretFlow proxy objects these may not reach the remote instance —
    # verify.
    expected_num_rounds = int(
        server_config["num_rounds"] * server_config["participate_ratio"]
    )
    init_weight = clients_dict[0].get_params()
    global_testloader = DataLoader(testset, batch_size=128, shuffle=False)
    for cid in clients_dict.keys():
        print(f"Progress:{cid}/{len(clients_dict)}")
        client = clients_dict[cid]
        client.set_params(init_weight, exclude_keys=set())
        for r in range(1, expected_num_rounds + 1):
            setup_seed(r + args.global_seed)
            client.training(r, client.client_config["num_epochs"])
            client.testing(r, global_testloader)
            print(
                f"Round: {r}/{expected_num_rounds}",
                client.test_acc_dict[r]["acc_by_criteria"],
            )
        # Drop large per-client objects before pickling clients_dict.
        client.model = None
        client.trainloader = None
        client.trainset = None
        client.new_state_dict = None
    save_to_pkl(clients_dict, path + "_final_clients_dict.pkl")