该Notebook在隐语环境中实现了论文FedProto: Federated Prototype Learning across Heterogeneous Clients的数据划分和联邦学习方法。

In [77]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.utils.model_zoo as model_zoo
from tqdm import tqdm
import copy, sys
import time
import numpy as np
from torchvision import datasets, transforms
import random
from secretflow import PYUObject, proxy


一些超参数设置，可以通过同级目录下的config.ini设置对应不同数据集的超参数

In [78]:
import argparse
import configparser

def args_parser():
    parser = argparse.ArgumentParser()

    # federated arguments (Notation for the arguments followed from paper)
    parser.add_argument('--rounds', type=int, default=100,
                        help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=20,
                        help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.04,
                        help='the fraction of clients: C')
    parser.add_argument('--train_ep', type=int, default=1,
                        help="the number of local episodes: E")
    parser.add_argument('--local_bs', type=int, default=4,
                        help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')

    # model arguments
    parser.add_argument('--model', type=str, default='cnn', help='model name')
    parser.add_argument('--alg', type=str, default='fedproto', help="algorithms")
    parser.add_argument('--mode', type=str, default='task_heter', help="mode")
    parser.add_argument('--num_channels', type=int, default=1, help="number \
                        of channels of imgs")
    parser.add_argument('--norm', type=str, default='batch_norm',
                        help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32,
                        help="number of filters for conv nets -- 32 for \
                        mini-imagenet, 64 for omiglot.")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than \
                        strided convolutions")

    # other arguments
    parser.add_argument('--data_dir', type=str, default='../data/', help="directory of dataset")
    parser.add_argument('--dataset', type=str, default='mnist', help="name \
                        of dataset")
    parser.add_argument('--num_classes', type=int, default=10, help="number \
                        of classes")
    parser.add_argument('--device', default='cpu', help="To use cuda")
    parser.add_argument('--gpu', default=0, help="To use cuda, set \
                        to a specific GPU ID. Default set to use CPU.")
    parser.add_argument('--optimizer', type=str, default='sgd', help="type \
                        of optimizer")
    parser.add_argument('--iid', type=int, default=0,
                        help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0,
                        help='whether to use unequal data splits for  \
                        non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10,
                        help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1234, help='random seed')
    parser.add_argument('--test_ep', type=int, default=10, help="num of test episodes for evaluation")

    # Local arguments
    parser.add_argument('--ways', type=int, default=3, help="num of classes")
    parser.add_argument('--shots', type=int, default=100, help="num of shots")
    parser.add_argument('--train_shots_max', type=int, default=110, help="num of shots")
    parser.add_argument('--test_shots', type=int, default=15, help="num of shots")
    parser.add_argument('--stdev', type=int, default=2, help="stdev of ways")
    parser.add_argument('--ld', type=float, default=1, help="weight of proto loss")
    parser.add_argument('--ft_round', type=int, default=10, help="round of fine tuning")
    arg_list = None
    
    config = configparser.ConfigParser()
    config.read('config.ini')
    # 其实是个字典: 
    # print(config['train']['batch_size'])
    arg_list = []
    for k, v in config['train'].items():
        arg_list.append("--"+k)
        arg_list.append(v)

    args = parser.parse_args(arg_list)
    return args
args = args_parser()
print(args)

Namespace(alg='fedproto', data_dir='../data/', dataset='mnist', device='cpu', frac=0.04, ft_round=10, gpu=0, iid=0, ld=1, local_bs=4, lr=0.01, max_pool='True', mode='task_heter', model='cnn', momentum=0.5, norm='batch_norm', num_channels=1, num_classes=10, num_filters=32, num_users=8, optimizer='sgd', rounds=100, seed=1234, shots=100, stdev=2, stopping_rounds=10, test_ep=10, test_shots=15, train_ep=1, train_shots_max=110, unequal=0, verbose=1, ways=3)


In [79]:
def exp_details(args):
    print('\nExperimental details:')
    print(f'    Model     : {args.model}')
    print(f'    Optimizer : {args.optimizer}')
    print(f'    Learning  : {args.lr}')
    print(f'    Global Rounds   : {args.rounds}\n')

    print('    Federated parameters:')
    if args.iid:
        print('    IID')
    else:
        print('    Non-IID')
    print(f'    Fraction of users  : {args.frac}')
    print(f'    Local Batch size   : {args.local_bs}')
    print(f'    Local Epochs       : {args.train_ep}\n')
    return

In [80]:
class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)

客户端的本地模型结构

In [81]:
class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(int(320/20*20), 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x1 = F.relu(self.fc1(x))
        x = F.dropout(x1, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1), x1

客户端类，主要包含了初始化、数据集划分和本地更新等主要函数，本地更新的loss由交叉熵loss和本地原型、全局原型之间的mse_loss之和组成。

In [82]:
@proxy(PYUObject)
class Client(object):
    def __init__(self, args, dataset, idxs):
        self.args = args
        self.trainloader = self.train_val_test(dataset, list(idxs))
#         self.device = args.device
        self.device = 'cpu'
        self.criterion = nn.NLLLoss().to(self.device)
        self.model_param = None
        self.loss = None
        self.acc = None
        self.protos = None
        self.local_loss = None
        self.proto_loss = None
        self.global_protos = []
        self.model = CNNMnist(args=args).to(self.device)
        
    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes.
        """
        idxs_train = idxs[:int(1 * len(idxs))]
        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True, drop_last=True)

        return trainloader


    def update_weights_het(self, args, idx, global_round=round):
        # Set mode to train model
        self.model.train()
        epoch_loss = {'total':[],'1':[], '2':[], '3':[]}

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr,
                                        momentum=0.5)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
                                         weight_decay=1e-4)

        for iter in range(self.args.train_ep):
            batch_loss = {'total':[],'1':[], '2':[], '3':[]}
            agg_protos_label = {}
            for batch_idx, (images, label_g) in enumerate(self.trainloader):
                images, labels = images.to(self.device), label_g.to(self.device)

                # loss1: cross-entrophy loss, loss2: proto distance loss
                self.model.zero_grad()
                log_probs, protos = self.model(images)
                loss1 = self.criterion(log_probs, labels)

                loss_mse = nn.MSELoss()
                if len(self.global_protos) == 0:
                    loss2 = 0*loss1
                else:
                    proto_new = copy.deepcopy(protos.data)
                    i = 0
                    for label in labels:
                        if label.item() in self.global_protos.keys():
                            proto_new[i, :] = self.global_protos[label.item()][0].data
                        i += 1
                    loss2 = loss_mse(proto_new, protos)

                loss = loss1 + loss2 * args.ld
                loss.backward()
                optimizer.step()

                for i in range(len(labels)):
                    if label_g[i].item() in agg_protos_label:
                        agg_protos_label[label_g[i].item()].append(protos[i,:])
                    else:
                        agg_protos_label[label_g[i].item()] = [protos[i,:]]

                log_probs = log_probs[:, 0:args.num_classes]
                _, y_hat = log_probs.max(1)
                acc_val = torch.eq(y_hat, labels.squeeze()).float().mean()

                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | User: {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.3f} | Acc: {:.3f}'.format(
                        global_round, idx, iter, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader),
                        loss.item(),
                        acc_val.item()))
                batch_loss['total'].append(loss.item())
                batch_loss['1'].append(loss1.item())
                batch_loss['2'].append(loss2.item())
            epoch_loss['total'].append(sum(batch_loss['total'])/len(batch_loss['total']))
            epoch_loss['1'].append(sum(batch_loss['1']) / len(batch_loss['1']))
            epoch_loss['2'].append(sum(batch_loss['2']) / len(batch_loss['2']))

        epoch_loss['total'] = sum(epoch_loss['total']) / len(epoch_loss['total'])
        epoch_loss['1'] = sum(epoch_loss['1']) / len(epoch_loss['1'])
        epoch_loss['2'] = sum(epoch_loss['2']) / len(epoch_loss['2'])
        
        self.set_param_model(self.model.state_dict())
        self.set_loss(epoch_loss)
        self.set_acc(acc_val.item())
        self.set_protos(agg_protos_label)

#         return (model.state_dict(), epoch_loss, acc_val.item(), agg_protos_label)
    
    def set_param_model(self,param):
        self.model_param = param
    def get_param_model(self):
        return self.model_param

    def set_loss(self,loss):
        self.loss = loss
        self.local_loss = loss['total']
        self.proto_loss = loss['2']                      
    def get_loss(self):  
        return self.loss
    def get_local_loss(self):
        return self.local_loss
    def get_proto_loss(self):
        return self.proto_loss
    def set_acc(self,acc):
        self.acc = acc
    def get_acc(self):
        return self.acc
    def set_protos(self,protos):
        self.protos = protos
    def get_protos(self):
        return self.protos
    
    def set_global_protos(self,global_protos):
        
        self.global_protos=global_protos
       
        
    def get_weights(self):
        return self.model.state_dict()

    def set_weights(self, weights):
        self.model.load_state_dict(weights,strict=True)
    def agg_func(self,protos):
        """
        Returns the average of the weights.
        """

        for [label, proto_list] in protos.items():
            if len(proto_list) > 1:
                proto = 0 * proto_list[0].data
                for i in proto_list:
                    proto += i.data
                protos[label] = proto / len(proto_list)
            else:
                protos[label] = proto_list[0]

        return protos


服务器类，主要包含了初始化、全局原型聚合和测试全局模型、保存原型等函数。

In [83]:
@proxy(PYUObject)
class Server(object):
    def __init__(self,args):
        self.args = args
        self.device = 'cpu'
        self.acc_list_l = []
        self.acc_list_g = []
        self.loss_list = []
        self.model = CNNMnist(args=args).to(self.device)
    def proto_aggregation(self,local_protos_list):
        agg_protos_label = dict()
        for idx in local_protos_list:
            local_protos = local_protos_list[idx]
            for label in local_protos.keys():
                if label in agg_protos_label:
                    agg_protos_label[label].append(local_protos[label])
                else:
                    agg_protos_label[label] = [local_protos[label]]

        for [label, proto_list] in agg_protos_label.items():
            if len(proto_list) > 1:
                proto = 0 * proto_list[0].data
                for i in proto_list:
                    proto += i.data
                agg_protos_label[label] = [proto / len(proto_list)]
            else:
                agg_protos_label[label] = [proto_list[0].data]

        return agg_protos_label
    
    def test_inference_new_het_lt(self,args, local_weights_list_global,test_dataset, classes_list, user_groups_gt, global_protos=[]):
        """ Returns the test accuracy and loss.
        """
        loss, total, correct = 0.0, 0.0, 0.0
        loss_mse = nn.MSELoss()
        criterion = nn.NLLLoss().to(self.device)

#         acc_list_g = []
#         acc_list_l = []
#         loss_list = []
        for idx in range(args.num_users):
#             model = local_model_list[idx]
#             model.to(self.device)
            self.model.load_state_dict(local_weights_list_global[idx],strict=True)
            testloader = DataLoader(DatasetSplit(test_dataset, user_groups_gt[idx]), batch_size=64, shuffle=True)

            # test (local model)
            self.model.eval()
            for batch_idx, (images, labels) in enumerate(testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                self.model.zero_grad()
                outputs, protos = self.model(images)

                batch_loss = criterion(outputs, labels)
                loss += batch_loss.item()

                # prediction
                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

            acc = correct / total
            print('| User: {} | Global Test Acc w/o protos: {:.3f}'.format(idx, acc))
            self.acc_list_l.append(acc)

            # test (use global proto)
            if global_protos!=[]:
                for batch_idx, (images, labels) in enumerate(testloader):
                    images, labels = images.to(self.device), labels.to(self.device)
                    self.model.zero_grad()
                    outputs, protos = self.model(images)

                    # compute the dist between protos and global_protos
                    a_large_num = 100
                    dist = a_large_num * torch.ones(size=(images.shape[0], args.num_classes)).to(self.device)  # initialize a distance matrix
                    for i in range(images.shape[0]):
                        for j in range(args.num_classes):
                            if j in global_protos.keys() and j in classes_list[idx]:
                                d = loss_mse(protos[i, :], global_protos[j][0])
                                dist[i, j] = d

                    # prediction
                    _, pred_labels = torch.min(dist, 1)
                    pred_labels = pred_labels.view(-1)
                    correct += torch.sum(torch.eq(pred_labels, labels)).item()
                    total += len(labels)

                    # compute loss
                    proto_new = copy.deepcopy(protos.data)
                    i = 0
                    for label in labels:
                        if label.item() in global_protos.keys():
                            proto_new[i, :] = global_protos[label.item()][0].data
                        i += 1
                    loss2 = loss_mse(proto_new, protos)
                    if self.device == 'cuda':
                        loss2 = loss2.cpu().detach().numpy()
                    else:
                        loss2 = loss2.detach().numpy()

                acc = correct / total
                print('| User: {} | Global Test Acc with protos: {:.5f}'.format(idx, acc))
                self.acc_list_g.append(acc)
                self.loss_list.append(loss2)
#         self.set_acc_list_l(acc_list_l)
#         self.set_acc_list_g(acc_list_g)
#         self.set_loss_list(loss_list)
#         return acc_list_l, acc_list_g, loss_list
    
#     def set_acc_list_l(self,acc_list_l):
#         self.acc_list_l = acc_list_l
    def get_acc_list_l(self):
        return self.acc_list_l
#     def set_acc_list_g(self,acc_list_g):
#         self.acc_list_g = acc_list_g
    def get_acc_list_g(self):
        return self.acc_list_g
#     def set_loss_list(self,loss_list):
#         self.loss_list = loss_list
    def get_loss_list(self):
        return self.loss_list
    def save_protos(self,args, test_dataset, user_groups_gt):
        """ Returns the test accuracy and loss.
        """
        loss, total, correct = 0.0, 0.0, 0.0

        device = self.args.device
        criterion = nn.NLLLoss().to(device)

        agg_protos_label = {}
        for idx in range(self.args.num_users):
            agg_protos_label[idx] = {}
#             model = local_model_list[idx]
#             model.to(self.args.device)
            testloader = DataLoader(DatasetSplit(test_dataset, user_groups_gt[idx]), batch_size=64, shuffle=True)

            self.model.eval()
            for batch_idx, (images, labels) in enumerate(testloader):
                images, labels = images.to(device), labels.to(device)

                self.model.zero_grad()
                outputs, protos = self.model(images)

                batch_loss = criterion(outputs, labels)
                loss += batch_loss.item()

                # prediction
                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

                for i in range(len(labels)):
                    if labels[i].item() in agg_protos_label[idx]:
                        agg_protos_label[idx][labels[i].item()].append(protos[i, :])
                    else:
                        agg_protos_label[idx][labels[i].item()] = [protos[i, :]]

        x = []
        y = []
        d = []
        for i in range(self.args.num_users):
            for label in agg_protos_label[i].keys():
                for proto in agg_protos_label[i][label]:
                    if args.device == 'cuda':
                        tmp = proto.cpu().detach().numpy()
                    else:
                        tmp = proto.detach().numpy()
                    x.append(tmp)
                    y.append(label)
                    d.append(i)

        x = np.array(x)
        y = np.array(y)
        d = np.array(d)
        np.save('./' + args.alg + '_protos.npy', x)
        np.save('./' + args.alg + '_labels.npy', y)
        np.save('./' + args.alg + '_idx.npy', d)

        print("Save protos and labels successfully.")

多种数据集的划分方式

In [84]:
def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

def mnist_noniid(args, dataset, num_users, n_list, k_list):
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """

    # 60,000 training imgs -->  200 imgs/shard X 300 shards
    num_shards, num_imgs = 10, 6000
    idx_shard = [i for i in range(num_shards)]
    dict_users = {}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()
    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]
    label_begin = {}
    cnt=0
    for i in idxs_labels[1,:]:
        if i not in label_begin:
                label_begin[i] = cnt
        cnt+=1

    classes_list = []
    for i in range(num_users):
        n = n_list[i]
        k = k_list[i]
        k_len = args.train_shots_max
        classes = random.sample(range(0,args.num_classes), n)
        classes = np.sort(classes)
        print("user {:d}: {:d}-way {:d}-shot".format(i + 1, n, k))
        print("classes:", classes)
        user_data = np.array([])
        for each_class in classes:
            # begin = i*10 + label_begin[each_class.item()]
            begin = i * k_len + label_begin[each_class.item()]
            user_data = np.concatenate((user_data, idxs[begin : begin+k]),axis=0)
        dict_users[i] = user_data
        classes_list.append(classes)

    return dict_users, classes_list

def mnist_noniid_lt(args, test_dataset, num_users, n_list, k_list, classes_list):
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """

    # 60,000 training imgs -->  200 imgs/shard X 300 shards
    num_shards, num_imgs = 10, 1000
    idx_shard = [i for i in range(num_shards)]
    dict_users = {}
    idxs = np.arange(num_shards*num_imgs)
    labels = test_dataset.train_labels.numpy()
    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]
    label_begin = {}
    cnt=0
    for i in idxs_labels[1,:]:
        if i not in label_begin:
                label_begin[i] = cnt
        cnt+=1

    for i in range(num_users):
        k = 40 # 每个类选多少张做测试
        classes = classes_list[i]
        print("local test classes:", classes)
        user_data = np.array([])
        for each_class in classes:
            begin = i*40 + label_begin[each_class.item()]
            user_data = np.concatenate((user_data, idxs[begin : begin+k]),axis=0)
        dict_users[i] = user_data


    return dict_users

def mnist_noniid_unequal(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset s.t clients
    have unequal amount of data
    :param dataset:
    :param num_users:
    :returns a dict of clients with each clients assigned certain
    number of training imgs
    """
    # 60,000 training imgs --> 50 imgs/shard X 1200 shards
    num_shards, num_imgs = 1200, 50
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # Minimum and maximum shards assigned per client:
    min_shard = 1
    max_shard = 30

    # Divide the shards into random chunks for every client
    # s.t the sum of these chunks = num_shards
    random_shard_size = np.random.randint(min_shard, max_shard+1,
                                          size=num_users)
    random_shard_size = np.around(random_shard_size /
                                  sum(random_shard_size) * num_shards)
    random_shard_size = random_shard_size.astype(int)

    # Assign the shards randomly to each client
    if sum(random_shard_size) > num_shards:

        for i in range(num_users):
            # First assign each client 1 shard to ensure every client has
            # atleast one shard of data
            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        random_shard_size = random_shard_size-1

        # Next, randomly assign the remaining shards
        for i in range(num_users):
            if len(idx_shard) == 0:
                continue
            shard_size = random_shard_size[i]
            if shard_size > len(idx_shard):
                shard_size = len(idx_shard)
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)
    else:

        for i in range(num_users):
            shard_size = random_shard_size[i]
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        if len(idx_shard) > 0:
            # Add the leftover shards to the client with minimum images:
            shard_size = len(idx_shard)
            # Add the remaining shard to the client with lowest data
            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            for rand in rand_set:
                dict_users[k] = np.concatenate(
                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

    return dict_users

数据集下载

In [85]:
def get_dataset(args, n_list, k_list):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    data_dir = args.data_dir + args.dataset
    if args.dataset == 'mnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Chose uneuqal splits for every user
                user_groups = mnist_noniid_unequal(args, train_dataset, args.num_users)
            else:
                # Chose euqal splits for every user
                user_groups, classes_list = mnist_noniid(args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = mnist_noniid_lt(args, test_dataset, args.num_users, n_list, k_list, classes_list)
                classes_list_gt = classes_list

    elif args.dataset == 'femnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = femnist.FEMNIST(args, data_dir, train=True, download=True,
                                        transform=apply_transform)
        test_dataset = femnist.FEMNIST(args, data_dir, train=False, download=True,
                                       transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = femnist_iid(train_dataset, args.num_users)
            # print("TBD")
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Chose uneuqal splits for every user
                # user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
                user_groups = femnist_noniid_unequal(args, train_dataset, args.num_users)
                # print("TBD")
            else:
                # Chose euqal splits for every user
                user_groups, classes_list, classes_list_gt = femnist_noniid(args, args.num_users, n_list, k_list)
                user_groups_lt = femnist_noniid_lt(args, args.num_users, classes_list)

    elif args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=trans_cifar10_train)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True, transform=trans_cifar10_val)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Chose uneuqal splits for every user
                raise NotImplementedError()
            else:
                # Chose euqal splits for every user
                user_groups, classes_list, classes_list_gt = cifar10_noniid(args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = cifar10_noniid_lt(args, test_dataset, args.num_users, n_list, k_list, classes_list)


    return train_dataset, test_dataset, user_groups, user_groups_lt, classes_list, classes_list_gt




各个客户端完成本地更新，将更新后的本地原型上传服务器，服务器完成全局原型聚合，最后对全局模型进行测试。主要是需要注意server和clients之间的各数据交互过程。

In [86]:
def calculate_avg(total,length):
    return total / length
def add_and_div(loss,length):
    return sum(loss) / length
def mean(list_a):
    return np.mean(list_a)
def std(list_a):
    return np.std(list_a)
def FedProto_taskheter(args, train_dataset, test_dataset, user_groups, user_groups_lt, classes_list, clients, server,server_pyu):

    global_protos = []
    idxs_users = np.arange(args.num_users)

    train_loss, train_accuracy = [], []
    local_model_list = []
    for round in tqdm(range(args.rounds)):
        local_weights,local_weights_global, local_losses, local_protos = [], [],[], {}
        print(f'\n | Global Training Round : {round + 1} |\n')
        for idx,local_model in enumerate(clients):    
            
            local_model.update_weights_het(args, idx, global_round=round)
            w = local_model.get_param_model()
            loss = local_model.get_loss()
            acc = local_model.get_acc()
            protos = local_model.get_protos()
            agg_protos = local_model.agg_func(protos).to(server.device)
#             agg_protos = {1:[512,1,1],7:[512,1,1],0:[512,1,1],9:[512,1,1],}
            local_weights.append(copy.deepcopy(w))
            local_weights_global.append(copy.deepcopy(w.to(server.device)))
            local_losses.append(copy.deepcopy(local_model.get_local_loss()).to(server.device))
            local_protos[idx] = agg_protos          

        local_weights_list = local_weights
        local_weights_list_global =local_weights_global

        for idx,local_model in enumerate(clients):
            local_model.set_weights(local_weights_list[idx])
#             local_model_list.append(local_model.to(server.device))
        # update global weights dict:{10}={1:[512,1,1],...}
#         setting = []
        global_protos = server.proto_aggregation(local_protos)
        for local_model in clients:
            local_protos = global_protos.to(local_model.device)
            local_model.set_global_protos(local_protos)
            
#             setting.append(ret)
#         sf.wait(setting)
        loss_avg = server_pyu(add_and_div)(local_losses,len(local_losses))
        train_loss.append(loss_avg)
    global_protos = global_protos.to(server.device)
#     local_weights_list=local_weights_list.to(server.device)
    server.test_inference_new_het_lt(args, local_weights_list_global,test_dataset, classes_list, user_groups_lt, global_protos)
    acc_list_l = server.get_acc_list_l()
    acc_list_g = server.get_acc_list_g()
    loss_list = server.get_loss_list()
    print('For all users (with protos), mean of test acc is ',sf.reveal(server_pyu(mean)(acc_list_g)), 'std of test acc is ',sf.reveal(server_pyu(std)(acc_list_g)))
    print('For all users (w/o protos), mean of test acc is ',sf.reveal(server_pyu(mean)(acc_list_l)), 'std of test acc is ', sf.reveal(server_pyu(std)(acc_list_l)))
    print('For all users (with protos), mean of proto loss is ', sf.reveal(server_pyu(mean)(loss_list)),'std of test acc is ', sf.reveal(server_pyu(std)(loss_list)))

    # save protos
#     if args.dataset == 'mnist':
#         server.save_protos(args, test_dataset, user_groups_lt)

在隐语平台实现FedProto，主要是一些初始化的内容，包含数据集和客户端、服务器的初始化，最后执行FedProto_taskheter

In [87]:
import secretflow as sf


start_time = time.time()

args = args_parser()
exp_details(args)

# set random seeds
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if args.device == 'cuda':
    torch.cuda.set_device(args.gpu)
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
else:
    torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
print(args.device)
print(torch.cuda.is_available())
n_list = np.random.randint(max(2, args.ways - args.stdev), min(args.num_classes, args.ways + args.stdev + 1), args.num_users)
if args.dataset == 'mnist':
    k_list = np.random.randint(args.shots - args.stdev + 1 , args.shots + args.stdev - 1, args.num_users)
train_dataset, test_dataset, user_groups, user_groups_lt, classes_list, classes_list_gt = get_dataset(args, n_list, k_list)
# local_model_list = []
# for i in range(args.num_users):
#     if args.dataset == 'mnist':
#         args.out_channels = 20
#         local_model = CNNMnist(args=args)
#         local_model.to(args.device)
#         local_model.train()
#         local_model_list.append(local_model)
sf.shutdown()
# clients_list_init= []
# for i in np.arange(args.num_users):
#     client_i_init = "client"+str(i+1)
#     clients_list_init.append(client_i_init)
# clients_list_init.append("server")
# sf.init(["client_1", "client_2", "client_3", "client_4", "client_5", "client_6", "client_7", "client_8", "client_9", "client_10", "client_11", "client_12", "client_13", "client_14", "client_15", "client_16", "client_17", "client_18", "client_19", "client_20", "server"], address='local', num_gpus=1)
sf.init(["client_1", "client_2", "client_3", "client_4", "client_5", "client_6", "client_7", "client_8","server"], address='local', num_gpus=1)
clients = []
for i in np.arange(args.num_users):
    client_i = "client_"+str(i+1)
    print(client_i)
    client_i_pyu = sf.PYU(client_i)
    client_i = Client(args=args, dataset=train_dataset, idxs=user_groups[i],device=client_i_pyu)
    
    clients.append(client_i)
server_pyu = sf.PYU("server")
server = Server(args,device = server_pyu)
print("clients",clients)
for idx,local_model in enumerate(clients):
    print(idx,local_model)

FedProto_taskheter(args, train_dataset, test_dataset, user_groups, user_groups_lt, classes_list,clients, server,server_pyu)



Experimental details:
    Model     : cnn
    Optimizer : sgd
    Learning  : 0.01
    Global Rounds   : 100

    Federated parameters:
    Non-IID
    Fraction of users  : 0.04
    Local Batch size   : 4
    Local Epochs       : 1

cuda
True
user 1: 5-way 100-shot
classes: [0 1 4 7 9]
user 2: 5-way 100-shot
classes: [0 1 2 6 8]
user 3: 4-way 100-shot
classes: [0 3 6 8]
user 4: 3-way 100-shot
classes: [0 5 7]
user 5: 2-way 99-shot
classes: [7 9]
user 6: 2-way 99-shot
classes: [1 2]
user 7: 2-way 100-shot
classes: [1 2]
user 8: 3-way 99-shot
classes: [0 7 8]
local test classes: [0 1 4 7 9]
local test classes: [0 1 2 6 8]
local test classes: [0 3 6 8]
local test classes: [0 5 7]
local test classes: [7 9]
local test classes: [1 2]
local test classes: [1 2]
local test classes: [0 7 8]


2024-11-04 15:51:47,459	INFO worker.py:1538 -- Started a local Ray instance.
INFO:root:Create proxy actor <class '__main__.Client'> with party client_1.
INFO:root:Create proxy actor <class '__main__.Client'> with party client_2.
INFO:root:Create proxy actor <class '__main__.Client'> with party client_3.


client_1
client_2
client_3


INFO:root:Create proxy actor <class '__main__.Client'> with party client_4.
INFO:root:Create proxy actor <class '__main__.Client'> with party client_5.


client_4
client_5


INFO:root:Create proxy actor <class '__main__.Client'> with party client_6.
INFO:root:Create proxy actor <class '__main__.Client'> with party client_7.


client_6
client_7


INFO:root:Create proxy actor <class '__main__.Client'> with party client_8.
INFO:root:Create proxy actor <class '__main__.Server'> with party server.


client_8
clients [<__main__.ActorProxy(Client) object at 0x7f0f65e91880>, <__main__.ActorProxy(Client) object at 0x7f0f6633fa30>, <__main__.ActorProxy(Client) object at 0x7f0f65f03760>, <__main__.ActorProxy(Client) object at 0x7f0f6635d910>, <__main__.ActorProxy(Client) object at 0x7f0ed5ed4040>, <__main__.ActorProxy(Client) object at 0x7f0f6634dfa0>, <__main__.ActorProxy(Client) object at 0x7f0f663bdbe0>, <__main__.ActorProxy(Client) object at 0x7f0ed602be20>]
0 <__main__.ActorProxy(Client) object at 0x7f0f65e91880>
1 <__main__.ActorProxy(Client) object at 0x7f0f6633fa30>
2 <__main__.ActorProxy(Client) object at 0x7f0f65f03760>
3 <__main__.ActorProxy(Client) object at 0x7f0f6635d910>
4 <__main__.ActorProxy(Client) object at 0x7f0ed5ed4040>
5 <__main__.ActorProxy(Client) object at 0x7f0f6634dfa0>
6 <__main__.ActorProxy(Client) object at 0x7f0f663bdbe0>
7 <__main__.ActorProxy(Client) object at 0x7f0ed602be20>


  0%|                                                                                           | 0/100 [00:00<?, ?it/s]


 | Global Training Round : 1 |







  1%|▊                                                                                  | 1/100 [00:02<04:04,  2.47s/it]


 | Global Training Round : 2 |



  2%|█▋                                                                                 | 2/100 [00:03<02:12,  1.35s/it]


 | Global Training Round : 3 |



  3%|██▍                                                                                | 3/100 [00:03<01:39,  1.02s/it]


 | Global Training Round : 4 |



  4%|███▎                                                                               | 4/100 [00:04<01:21,  1.17it/s]


 | Global Training Round : 5 |



  5%|████▏                                                                              | 5/100 [00:04<01:09,  1.36it/s]


 | Global Training Round : 6 |



  6%|████▉                                                                              | 6/100 [00:05<01:00,  1.54it/s]


 | Global Training Round : 7 |



  7%|█████▊                                                                             | 7/100 [00:05<00:55,  1.68it/s]


 | Global Training Round : 8 |



  8%|██████▋                                                                            | 8/100 [00:06<00:51,  1.78it/s]


 | Global Training Round : 9 |



  9%|███████▍                                                                           | 9/100 [00:06<00:49,  1.85it/s]


 | Global Training Round : 10 |



 10%|████████▏                                                                         | 10/100 [00:07<00:46,  1.95it/s]


 | Global Training Round : 11 |



 11%|█████████                                                                         | 11/100 [00:07<00:41,  2.16it/s]


 | Global Training Round : 12 |



 12%|█████████▊                                                                        | 12/100 [00:07<00:37,  2.38it/s]


 | Global Training Round : 13 |



 13%|██████████▋                                                                       | 13/100 [00:08<00:35,  2.48it/s]


 | Global Training Round : 14 |



 14%|███████████▍                                                                      | 14/100 [00:08<00:32,  2.61it/s]


 | Global Training Round : 15 |



 15%|████████████▎                                                                     | 15/100 [00:08<00:31,  2.71it/s]


 | Global Training Round : 16 |



 16%|█████████████                                                                     | 16/100 [00:09<00:30,  2.73it/s]


 | Global Training Round : 17 |



 17%|█████████████▉                                                                    | 17/100 [00:09<00:29,  2.79it/s]


 | Global Training Round : 18 |



 18%|██████████████▊                                                                   | 18/100 [00:09<00:28,  2.88it/s]


 | Global Training Round : 19 |



 19%|███████████████▌                                                                  | 19/100 [00:10<00:27,  2.93it/s]


 | Global Training Round : 20 |



 20%|████████████████▍                                                                 | 20/100 [00:10<00:26,  3.00it/s]


 | Global Training Round : 21 |



 21%|█████████████████▏                                                                | 21/100 [00:10<00:26,  3.03it/s]


 | Global Training Round : 22 |



 22%|██████████████████                                                                | 22/100 [00:11<00:25,  3.10it/s]


 | Global Training Round : 23 |



 23%|██████████████████▊                                                               | 23/100 [00:11<00:25,  3.04it/s]


 | Global Training Round : 24 |



 24%|███████████████████▋                                                              | 24/100 [00:11<00:24,  3.11it/s]


 | Global Training Round : 25 |



 25%|████████████████████▌                                                             | 25/100 [00:12<00:24,  3.09it/s]


 | Global Training Round : 26 |



 26%|█████████████████████▎                                                            | 26/100 [00:12<00:23,  3.12it/s]


 | Global Training Round : 27 |



 27%|██████████████████████▏                                                           | 27/100 [00:12<00:23,  3.05it/s]


 | Global Training Round : 28 |



 28%|██████████████████████▉                                                           | 28/100 [00:13<00:26,  2.76it/s]


 | Global Training Round : 29 |



 29%|███████████████████████▊                                                          | 29/100 [00:13<00:26,  2.70it/s]


 | Global Training Round : 30 |



 30%|████████████████████████▌                                                         | 30/100 [00:13<00:25,  2.78it/s]


 | Global Training Round : 31 |



 31%|█████████████████████████▍                                                        | 31/100 [00:14<00:24,  2.78it/s]


 | Global Training Round : 32 |



 32%|██████████████████████████▏                                                       | 32/100 [00:14<00:24,  2.81it/s]


 | Global Training Round : 33 |



 33%|███████████████████████████                                                       | 33/100 [00:15<00:23,  2.81it/s]


 | Global Training Round : 34 |



 34%|███████████████████████████▉                                                      | 34/100 [00:15<00:23,  2.82it/s]


 | Global Training Round : 35 |



 35%|████████████████████████████▋                                                     | 35/100 [00:15<00:23,  2.82it/s]


 | Global Training Round : 36 |



 36%|█████████████████████████████▌                                                    | 36/100 [00:16<00:22,  2.88it/s]


 | Global Training Round : 37 |



 37%|██████████████████████████████▎                                                   | 37/100 [00:16<00:22,  2.86it/s]


 | Global Training Round : 38 |



 38%|███████████████████████████████▏                                                  | 38/100 [00:16<00:20,  2.96it/s]


 | Global Training Round : 39 |



 39%|███████████████████████████████▉                                                  | 39/100 [00:16<00:18,  3.29it/s]


 | Global Training Round : 40 |



 40%|████████████████████████████████▊                                                 | 40/100 [00:17<00:19,  3.12it/s]


 | Global Training Round : 41 |



 41%|█████████████████████████████████▌                                                | 41/100 [00:17<00:19,  3.10it/s]


 | Global Training Round : 42 |



 42%|██████████████████████████████████▍                                               | 42/100 [00:17<00:18,  3.09it/s]


 | Global Training Round : 43 |



 43%|███████████████████████████████████▎                                              | 43/100 [00:18<00:18,  3.03it/s]


 | Global Training Round : 44 |



 44%|████████████████████████████████████                                              | 44/100 [00:18<00:18,  2.99it/s]


 | Global Training Round : 45 |



 45%|████████████████████████████████████▉                                             | 45/100 [00:19<00:18,  2.98it/s]


 | Global Training Round : 46 |



 46%|█████████████████████████████████████▋                                            | 46/100 [00:19<00:17,  3.01it/s]


 | Global Training Round : 47 |



 47%|██████████████████████████████████████▌                                           | 47/100 [00:19<00:17,  3.03it/s]


 | Global Training Round : 48 |



 48%|███████████████████████████████████████▎                                          | 48/100 [00:20<00:17,  3.02it/s]


 | Global Training Round : 49 |



 49%|████████████████████████████████████████▏                                         | 49/100 [00:20<00:16,  3.08it/s]


 | Global Training Round : 50 |



 50%|█████████████████████████████████████████                                         | 50/100 [00:20<00:15,  3.13it/s]


 | Global Training Round : 51 |



 51%|█████████████████████████████████████████▊                                        | 51/100 [00:20<00:15,  3.15it/s]


 | Global Training Round : 52 |



 52%|██████████████████████████████████████████▋                                       | 52/100 [00:21<00:15,  3.04it/s]


 | Global Training Round : 53 |



 53%|███████████████████████████████████████████▍                                      | 53/100 [00:21<00:15,  3.07it/s]


 | Global Training Round : 54 |



 54%|████████████████████████████████████████████▎                                     | 54/100 [00:21<00:15,  3.03it/s]


 | Global Training Round : 55 |



 55%|█████████████████████████████████████████████                                     | 55/100 [00:22<00:14,  3.01it/s]


 | Global Training Round : 56 |



 56%|█████████████████████████████████████████████▉                                    | 56/100 [00:22<00:14,  3.01it/s]


 | Global Training Round : 57 |



 57%|██████████████████████████████████████████████▋                                   | 57/100 [00:22<00:14,  2.87it/s]


 | Global Training Round : 58 |



 58%|███████████████████████████████████████████████▌                                  | 58/100 [00:23<00:14,  2.87it/s]


 | Global Training Round : 59 |



 59%|████████████████████████████████████████████████▍                                 | 59/100 [00:23<00:13,  2.95it/s]


 | Global Training Round : 60 |



 60%|█████████████████████████████████████████████████▏                                | 60/100 [00:23<00:13,  3.01it/s]


 | Global Training Round : 61 |



 61%|██████████████████████████████████████████████████                                | 61/100 [00:24<00:13,  2.92it/s]


 | Global Training Round : 62 |



 62%|██████████████████████████████████████████████████▊                               | 62/100 [00:24<00:12,  2.98it/s]


 | Global Training Round : 63 |



 63%|███████████████████████████████████████████████████▋                              | 63/100 [00:24<00:12,  3.02it/s]


 | Global Training Round : 64 |



 64%|████████████████████████████████████████████████████▍                             | 64/100 [00:25<00:12,  2.98it/s]


 | Global Training Round : 65 |



 65%|█████████████████████████████████████████████████████▎                            | 65/100 [00:25<00:11,  2.96it/s]


 | Global Training Round : 66 |



 66%|██████████████████████████████████████████████████████                            | 66/100 [00:26<00:11,  2.91it/s]


 | Global Training Round : 67 |



 67%|██████████████████████████████████████████████████████▉                           | 67/100 [00:26<00:11,  2.92it/s]


 | Global Training Round : 68 |



 68%|███████████████████████████████████████████████████████▊                          | 68/100 [00:26<00:11,  2.89it/s]


 | Global Training Round : 69 |



 69%|████████████████████████████████████████████████████████▌                         | 69/100 [00:27<00:10,  2.93it/s]


 | Global Training Round : 70 |



 70%|█████████████████████████████████████████████████████████▍                        | 70/100 [00:27<00:10,  2.90it/s]


 | Global Training Round : 71 |



 71%|██████████████████████████████████████████████████████████▏                       | 71/100 [00:27<00:09,  2.91it/s]


 | Global Training Round : 72 |



 72%|███████████████████████████████████████████████████████████                       | 72/100 [00:28<00:09,  2.96it/s]


 | Global Training Round : 73 |



 73%|███████████████████████████████████████████████████████████▊                      | 73/100 [00:28<00:09,  2.95it/s]


 | Global Training Round : 74 |



 74%|████████████████████████████████████████████████████████████▋                     | 74/100 [00:28<00:08,  2.98it/s]


 | Global Training Round : 75 |



 75%|█████████████████████████████████████████████████████████████▌                    | 75/100 [00:29<00:08,  2.99it/s]


 | Global Training Round : 76 |



 76%|██████████████████████████████████████████████████████████████▎                   | 76/100 [00:29<00:08,  2.99it/s]


 | Global Training Round : 77 |



 77%|███████████████████████████████████████████████████████████████▏                  | 77/100 [00:29<00:07,  2.95it/s]


 | Global Training Round : 78 |



 78%|███████████████████████████████████████████████████████████████▉                  | 78/100 [00:30<00:07,  2.93it/s]


 | Global Training Round : 79 |



 79%|████████████████████████████████████████████████████████████████▊                 | 79/100 [00:30<00:07,  2.90it/s]


 | Global Training Round : 80 |



 80%|█████████████████████████████████████████████████████████████████▌                | 80/100 [00:30<00:06,  2.89it/s]


 | Global Training Round : 81 |



 81%|██████████████████████████████████████████████████████████████████▍               | 81/100 [00:31<00:06,  2.85it/s]


 | Global Training Round : 82 |



 82%|███████████████████████████████████████████████████████████████████▏              | 82/100 [00:31<00:06,  2.87it/s]


 | Global Training Round : 83 |



 83%|████████████████████████████████████████████████████████████████████              | 83/100 [00:31<00:05,  2.88it/s]


 | Global Training Round : 84 |



 84%|████████████████████████████████████████████████████████████████████▉             | 84/100 [00:32<00:05,  2.89it/s]


 | Global Training Round : 85 |



 85%|█████████████████████████████████████████████████████████████████████▋            | 85/100 [00:32<00:05,  2.76it/s]


 | Global Training Round : 86 |



 86%|██████████████████████████████████████████████████████████████████████▌           | 86/100 [00:32<00:05,  2.72it/s]


 | Global Training Round : 87 |



 87%|███████████████████████████████████████████████████████████████████████▎          | 87/100 [00:33<00:04,  2.83it/s]


 | Global Training Round : 88 |



 88%|████████████████████████████████████████████████████████████████████████▏         | 88/100 [00:33<00:04,  2.83it/s]


 | Global Training Round : 89 |



 89%|████████████████████████████████████████████████████████████████████████▉         | 89/100 [00:34<00:03,  2.85it/s]


 | Global Training Round : 90 |



 90%|█████████████████████████████████████████████████████████████████████████▊        | 90/100 [00:34<00:03,  2.93it/s]


 | Global Training Round : 91 |



 91%|██████████████████████████████████████████████████████████████████████████▌       | 91/100 [00:34<00:03,  2.96it/s]


 | Global Training Round : 92 |



 92%|███████████████████████████████████████████████████████████████████████████▍      | 92/100 [00:34<00:02,  3.00it/s]


 | Global Training Round : 93 |



 93%|████████████████████████████████████████████████████████████████████████████▎     | 93/100 [00:35<00:02,  3.02it/s]


 | Global Training Round : 94 |



 94%|█████████████████████████████████████████████████████████████████████████████     | 94/100 [00:35<00:01,  3.03it/s]


 | Global Training Round : 95 |



 95%|█████████████████████████████████████████████████████████████████████████████▉    | 95/100 [00:35<00:01,  3.03it/s]


 | Global Training Round : 96 |



 96%|██████████████████████████████████████████████████████████████████████████████▋   | 96/100 [00:36<00:01,  3.02it/s]


 | Global Training Round : 97 |



 97%|███████████████████████████████████████████████████████████████████████████████▌  | 97/100 [00:36<00:00,  3.07it/s]


 | Global Training Round : 98 |



 98%|████████████████████████████████████████████████████████████████████████████████▎ | 98/100 [00:36<00:00,  2.95it/s]


 | Global Training Round : 99 |



 99%|█████████████████████████████████████████████████████████████████████████████████▏| 99/100 [00:37<00:00,  2.90it/s]


 | Global Training Round : 100 |



100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:37<00:00,  2.65it/s]






[2m[36m(Server pid=91420)[0m | User: 0 | Global Test Acc w/o protos: 0.970
[2m[36m(Server pid=91420)[0m | User: 0 | Global Test Acc with protos: 0.97000
[2m[36m(Server pid=91420)[0m | User: 1 | Global Test Acc w/o protos: 0.980
[2m[36m(Server pid=91420)[0m | User: 1 | Global Test Acc with protos: 0.98500




[2m[36m(Server pid=91420)[0m | User: 2 | Global Test Acc w/o protos: 0.986
[2m[36m(Server pid=91420)[0m | User: 2 | Global Test Acc with protos: 0.98661
[2m[36m(Server pid=91420)[0m | User: 3 | Global Test Acc w/o protos: 0.987
[2m[36m(Server pid=91420)[0m | User: 3 | Global Test Acc with protos: 0.98750
[2m[36m(Server pid=91420)[0m | User: 4 | Global Test Acc w/o protos: 0.987
[2m[36m(Server pid=91420)[0m | User: 4 | Global Test Acc with protos: 0.98684
[2m[36m(Server pid=91420)[0m | User: 5 | Global Test Acc w/o protos: 0.988
[2m[36m(Server pid=91420)[0m | User: 5 | Global Test Acc with protos: 0.98750
[2m[36m(Server pid=91420)[0m | User: 6 | Global Test Acc w/o protos: 0.988
[2m[36m(Server pid=91420)[0m | User: 6 | Global Test Acc with protos: 0.98859
[2m[36m(Server pid=91420)[0m | User: 7 | Global Test Acc w/o protos: 0.989
[2m[36m(Server pid=91420)[0m | User: 7 | Global Test Acc with protos: 0.98942
For all users (with protos), mean of test acc 