In [1]:
import torch
from torchvision import datasets, transforms
import copy
import numpy as np
import random
from tqdm import trange

# 参数设置

In [2]:
#调参：for MNIST, B = 100, E =1, η = 0.01, decay rate = 0.995;
###   for CIFAR-10, B = 100, E = 1, η = 0.1, decay rate = 0.992
import argparse

def args_parser(argv=None):
    """Build the experiment configuration namespace.

    Args:
        argv: optional sequence of command-line style tokens, e.g.
            ``['--dataset', 'mnist', '--rounds', '5']``. When ``None``
            (the default) an empty list is parsed, so every option takes
            its default value and ``sys.argv`` is never consulted.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()

    # federated arguments
    parser.add_argument('--fed', type=str, default='fedavg', help="federated optimization algorithm")
    parser.add_argument('--mu', type=float, default=1e-2, help='hyper parameter for fedprox')
    parser.add_argument('--rounds', type=int, default=100, help="total number of communication rounds")   #default=200
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.5, help="fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=10, help="number of local epochs: E")  #default=10
    parser.add_argument('--min_le', type=int, default=5, help="minimum number of local epoch")
    parser.add_argument('--max_le', type=int, default=15, help="maximum number of local epoch")
    parser.add_argument('--local_bs', type=int, default=20, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=20, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.001, help="client learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--classwise', type=int, default=1000, help="number of images for each class (global dataset)")
    parser.add_argument('--alpha', type=float, default=0.05, help="random portion of global dataset")
    parser.add_argument('--beta', type=int, default=0, help="DROP_RATE")

    # other arguments
    parser.add_argument('--dataset', type=str, default='fashion-mnist', help="name of dataset")
    parser.add_argument('--model', type=str, default='cnn', help='model name')
    parser.add_argument('--sampling', type=str, default='noniid_two', help="sampling method")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    parser.add_argument('--sys_homo', action='store_true', help='no system heterogeneity')
    parser.add_argument('--tsboard', action='store_true', help='tensorboard')
    parser.add_argument('--debug', action='store_true', help='debug')

    # Parse an explicit token list (empty by default) instead of sys.argv.
    args = parser.parse_args(args=[] if argv is None else list(argv))

    return args

# 网络结构

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

class MLP(nn.Module):
    """Single-hidden-layer perceptron: Linear -> Dropout -> ReLU -> Linear."""

    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        # Collapse dims 1, -2 and -1 of every sample into one feature vector.
        flat_features = x.shape[1] * x.shape[-2] * x.shape[-1]
        hidden = self.layer_input(x.view(-1, flat_features))
        # Dropout is applied before the activation, as in the original layout.
        hidden = self.relu(self.dropout(hidden))
        return self.layer_hidden(hidden)

class CNN_v1(nn.Module):
    """LeNet-style CNN whose first dense layer takes 256 flattened features.

    `args.num_channels` sets the input channel count and `args.num_classes`
    the number of output logits.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        # Two conv -> ReLU -> 2x2 max-pool stages.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)

        x = torch.flatten(x, 1)
        # Two hidden dense layers with ReLU, then raw logits.
        for dense in (self.fc1, self.fc2):
            x = F.relu(dense(x))
        return self.fc3(x)

class CNN_v2(nn.Module):
    """LeNet-style CNN whose first dense layer takes 16 * 5 * 5 flattened features.

    `args.num_channels` sets the input channel count and `args.num_classes`
    the number of output logits.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        # Convolutional feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
        stage_one = F.max_pool2d(F.relu(self.conv1(x)), 2)
        stage_two = F.max_pool2d(F.relu(self.conv2(stage_one)), 2)

        # Classifier head on the flattened feature maps.
        features = torch.flatten(stage_two, 1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
        return self.fc3(hidden)

class Alexnet(nn.Module):
    """AlexNet-style network: conv feature stack, 6x6 adaptive pooling, 3-layer classifier.

    `args.num_channels` sets the input channel count and `args.num_classes`
    the number of output logits.
    """

    def __init__(self, args):
        super().__init__()

        # Convolutional stack; the list order fixes the Sequential indices
        # (and therefore the state_dict keys).
        feature_layers = [
            nn.Conv2d(args.num_channels, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Pool every feature map to 6x6 so the classifier input size is fixed.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        classifier_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, args.num_classes),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))


# 数据处理

### 划分共享数据

In [4]:
def uniform_distribute(dataset, args):
    """Pick a class-balanced random subset to act as globally shared data.

    For each class in ``range(args.num_classes)``,
    ``int(args.alpha * args.classwise)`` indices are drawn without
    replacement from the samples of that class.

    Args:
        dataset: object exposing ``__len__`` and ``targets`` (a tensor for
            'mnist' / 'fashion-mnist', a sequence for 'cifar').
        args: namespace providing ``dataset``, ``num_classes``, ``alpha``
            and ``classwise``.

    Returns:
        list of the selected dataset indices, grouped by class.
    """
    positions = np.arange(len(dataset))

    if args.dataset in ("mnist", "fashion-mnist"):
        targets = dataset.targets.numpy()
    elif args.dataset == "cifar":
        targets = np.array(dataset.targets)
    else:
        exit('Error: unrecognized dataset')

    # Pair every index with its label and order the pairs by label.
    paired = np.vstack((positions, targets))
    paired = paired[:, paired[1, :].argsort()]
    positions, targets = paired[0], paired[1]

    per_class = int(args.alpha * args.classwise)
    shared_idx = []
    for cls in range(args.num_classes):
        candidates = np.extract(targets == cls, positions)
        # replace=False: the same sample is never picked twice.
        shared_idx.extend(np.random.choice(candidates, per_class, replace=False))

    return shared_idx

def train_dg_split(dataset, args):
    """Split dataset indices into a server-held part and a client part.

    For each class, ``args.classwise`` indices are drawn without replacement
    for the server-held subset; the remaining indices of that class go to
    the client pool.

    Args:
        dataset: object exposing ``__len__`` and ``targets`` (a tensor for
            'mnist' / 'fashion-mnist', a sequence for 'cifar').
        args: namespace providing ``dataset``, ``num_classes`` and ``classwise``.

    Returns:
        (dg_idx, train_idx): two lists of dataset indices, each grouped by class.
    """
    positions = np.arange(len(dataset))

    if args.dataset in ("mnist", "fashion-mnist"):
        targets = dataset.targets.numpy()
    elif args.dataset == "cifar":
        targets = np.array(dataset.targets)
    else:
        exit('Error: unrecognized dataset')

    # Stack indices over labels, then order the columns by label.
    paired = np.vstack((positions, targets))
    paired = paired[:, paired[1, :].argsort()]
    positions, targets = paired[0], paired[1]

    server_idx, client_idx = [], []
    for cls in range(args.num_classes):
        members = np.extract(targets == cls, positions)

        # Server-held samples for this class, drawn without repetition.
        server_pick = np.random.choice(members, args.classwise, replace=False)
        # Everything not picked for the server stays available to clients.
        leftover = set(members) - set(server_pick)

        server_idx.extend(server_pick)
        client_idx.extend(leftover)

    return server_idx, client_idx


### 划分IID与Non-IID数据

In [5]:
def iid(dataset, num_users):
    """Partition dataset indices uniformly at random across clients.

    Every client receives ``int(len(dataset) / num_users)`` distinct indices;
    an index is never assigned to more than one client.

    :param dataset: any object supporting ``len()``
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    per_user = int(len(dataset)/num_users)
    remaining = list(range(len(dataset)))

    assignment = {}
    for user in range(num_users):
        # Draw this client's share without repetition ...
        chosen = set(np.random.choice(remaining, per_user, replace=False))
        assignment[user] = chosen
        # ... and remove it from the pool so later clients cannot reuse it.
        remaining = list(set(remaining) - chosen)

    return assignment

def noniid(dataset, args):
    """Give every client a randomly sized, disjoint random subset of the dataset.

    Client sizes are drawn uniformly from [100, 700]; the dataset must hold
    at least as many samples as the sum of all client sizes.

    :param dataset: any object supporting ``len()``
    :param args: namespace providing ``num_users``
    :return: dict mapping client id -> set of sample indices
    """
    num_dataset = len(dataset)
    min_num, max_num = 100, 700

    client_sizes = np.random.randint(min_num, max_num+1, size=args.num_users)
    total_requested = sum(client_sizes)
    print(f"Total number of datasets owned by clients : {total_requested}")

    # The dataset must be at least as large as everything handed out.
    assert num_dataset >= total_requested

    pool = np.arange(num_dataset)
    dict_users = {}
    for client, size in enumerate(client_sizes):
        taken = set(np.random.choice(pool, size, replace=False))
        # Shrink the pool so the next client draws from unassigned samples only.
        pool = list(set(pool) - taken)
        dict_users[client] = taken

    return dict_users

def noniid_one(dataset, args):
    """Single-class non-IID split: every client holds samples of one class only.

    Clients are laid out class by class; each class serves
    ``ceil(args.num_users / args.num_classes)`` clients (10 for the default
    100 users / 10 classes, matching the previously hard-coded value), and
    every client receives ``int(len(dataset) / args.num_users) - 100``
    samples drawn without replacement. Clients of the same class never share
    a sample. When ``num_users`` is not a multiple of ``num_classes`` a few
    extra client ids beyond ``num_users - 1`` are produced.

    Args:
        dataset: object exposing ``__len__`` and ``targets`` (a tensor for
            'mnist' / 'fashion-mnist', a sequence for 'cifar').
        args: namespace providing ``dataset``, ``num_users`` and ``num_classes``.

    Returns:
        dict mapping client id -> numpy array of sample indices.

    Raises:
        ValueError: if ``args.dataset`` is not recognised (previously this
            called ``exit()``), or - from ``np.random.choice`` - if a class
            has too few samples left for a client.
    """
    num_items = int(len(dataset)/args.num_users)-100
    print("num_items:",num_items)

    if args.dataset in ("mnist", "fashion-mnist"):
        labels = dataset.targets.numpy()
    elif args.dataset == "cifar":
        labels = np.array(dataset.targets)
    else:
        raise ValueError('Error: unrecognized dataset')

    # Pair indices with labels and order the pairs by label.
    idxs = np.arange(len(dataset))
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
    idxs = idxs_labels[0]
    labels = idxs_labels[1]

    # Ceiling division, so at least num_users client ids are always produced.
    users_per_class = -(-args.num_users // args.num_classes)

    dict_users = {}
    for i in range(args.num_classes):
        specific_class = np.extract(labels == i, idxs)
        print("len_specific_class",len(specific_class))
        for j in range(users_per_class):
            user_id = i*users_per_class+j
            dict_users[user_id] = np.random.choice(specific_class, num_items, replace=False)
            # Remove what was handed out so clients of this class stay disjoint.
            specific_class = list(set(specific_class)-set(dict_users[user_id]))
    return dict_users

def noniid_two(dataset, args):
    """Two-class non-IID split: every client holds samples of exactly two classes.

    For each primary class ``i`` a different secondary class is picked at
    random, and ``ceil(args.num_users / args.num_classes)`` clients are
    created (10 for the default 100 users / 10 classes, matching the
    previously hard-coded value). Each client receives ``num_items`` samples
    of the primary class and ``num_items`` of the secondary class, where
    ``num_items = int((int(len(dataset) / num_users) - 100) / 2)``.

    The split does not cover the whole dataset. The secondary-class pool is
    rebuilt for every primary class, so a sample can be handed to clients
    belonging to different primary classes.

    Args:
        dataset: object exposing ``__len__`` and ``targets`` (a tensor for
            'mnist' / 'fashion-mnist', a sequence for 'cifar').
        args: namespace providing ``dataset``, ``num_users`` and ``num_classes``.

    Returns:
        dict mapping client id -> list of sample indices.

    Raises:
        ValueError: if ``args.dataset`` is not recognised (previously this
            called ``exit()``), or if fewer than two classes are configured
            (previously an endless loop).
    """
    num_items = int((int(len(dataset)/args.num_users)-100)/2)
    print("num_items",num_items)

    if args.dataset in ("mnist", "fashion-mnist"):
        labels = dataset.targets.numpy()
    elif args.dataset == "cifar":
        labels = np.array(dataset.targets)
    else:
        raise ValueError('Error: unrecognized dataset')

    if args.num_classes < 2:
        # With a single class no distinct secondary class can ever be drawn.
        raise ValueError('Error: noniid_two needs at least two classes')

    # Pair indices with labels and order the pairs by label.
    idxs = np.arange(len(dataset))
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
    idxs = idxs_labels[0]
    labels = idxs_labels[1]

    # Ceiling division, so at least num_users client ids are always produced.
    users_per_class = -(-args.num_users // args.num_classes)

    dict_users = {}
    for i in range(args.num_classes):
        specific_class_one = np.extract(labels == i, idxs)
        L = list(range(args.num_classes))
        print(L,i)

        # Rejection-sample a secondary class different from the primary one.
        specific_class_two_idx = random.choice(L)
        while specific_class_two_idx == i:
            specific_class_two_idx = random.choice(L)
        print(specific_class_two_idx)

        specific_class_two = np.extract(labels == specific_class_two_idx, idxs)
        print("len_specific_class_one",len(specific_class_one))

        for j in range(users_per_class):
            user_id = i*users_per_class+j

            # Primary-class share, then shrink the primary pool.
            primary = np.random.choice(specific_class_one, num_items, replace=False)
            specific_class_one = list(set(specific_class_one)-set(primary))

            # Secondary-class share, merged with the primary share.
            secondary = np.random.choice(specific_class_two, num_items, replace=False)
            dict_users[user_id] = list(set(primary)|set(secondary))
            specific_class_two = list(set(specific_class_two)-set(dict_users[user_id]))
            print(len(dict_users[user_id]))

    return dict_users

In [6]:
# Build the configuration and select the compute device: the requested GPU
# when CUDA is available and a GPU id was given, otherwise the CPU.
args = args_parser()
if torch.cuda.is_available() and args.gpu != -1:
    args.device = torch.device('cuda:{}'.format(args.gpu))
else:
    args.device = torch.device('cpu')
print( "GPU:",torch.cuda.is_available())

# Seed every random number generator used by the notebook.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

GPU: True


In [None]:
# Load the selected dataset, carve out the server-held subset `dg`, and
# partition the remaining samples across clients.
if args.dataset == 'mnist':
    # ToTensor() converts each image to a tensor before normalisation.
    trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    dataset = datasets.MNIST('../autodl-tmp/data/mnist/', train=True, download=True, transform=trans_mnist)
    dataset_test = datasets.MNIST('../autodl-tmp/data/mnist/', train=False, download=True, transform=trans_mnist)
elif args.dataset == 'fashion-mnist':
    trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    dataset = datasets.FashionMNIST('../autodl-tmp/data/fashion-mnist/', train=True, download=True, transform=trans_mnist)
    dataset_test = datasets.FashionMNIST('../autodl-tmp/data/fashion-mnist/', train=False, download=True, transform=trans_mnist)
elif args.dataset == 'cifar':
    trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = datasets.CIFAR10('../autodl-tmp/data/cifar-10', train=True, download=True, transform=trans_cifar)
    dataset_test = datasets.CIFAR10('../autodl-tmp/data/cifar-10', train=False, download=True, transform=trans_cifar)
else:
    # Raise instead of calling exit(), so the failure is reported as an
    # ordinary exception in the notebook.
    raise ValueError('Error: unrecognized dataset')

# dg is the data kept by the server; dataset_train is the data that will be
# distributed to the clients. Both start as copies and are then narrowed.
dg = copy.deepcopy(dataset)
dataset_train = copy.deepcopy(dataset)

# Both results are lists of indices into `dataset`.
dg_idx, dataset_train_idx = train_dg_split(dataset, args)

dg.data, dataset_train.data = dataset.data[dg_idx], dataset.data[dataset_train_idx]
if args.dataset == 'cifar':
    # CIFAR targets are a plain list. Read the labels straight from it
    # rather than through dataset[i], which would decode and transform
    # every image just to obtain its label.
    dg.targets = [dataset.targets[i] for i in dg_idx]
    dataset_train.targets = [dataset.targets[i] for i in dataset_train_idx]
else:
    # MNIST-style targets are tensors and support list indexing directly.
    dg.targets, dataset_train.targets = dataset.targets[dg_idx], dataset.targets[dataset_train_idx]

# Assign sample indices to clients according to the requested distribution.
_args_based_samplers = {'noniid': noniid, 'noniid_two': noniid_two, 'noniid_one': noniid_one}
if args.sampling == 'iid':
    dict_users = iid(dataset_train, args.num_users)
elif args.sampling in _args_based_samplers:
    dict_users = _args_based_samplers[args.sampling](dataset_train, args)
else:
    raise ValueError('Error: unrecognized sampling')

img_size = dataset_train[0][0].shape

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../autodl-tmp/data/fashion-mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ../autodl-tmp/data/fashion-mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../autodl-tmp/data/fashion-mnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../autodl-tmp/data/fashion-mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ../autodl-tmp/data/fashion-mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../autodl-tmp/data/fashion-mnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../autodl-tmp/data/fashion-mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

In [None]:
def FedAvg(w, args):
    """Average a list of model state dicts entry by entry.

    Args:
        w: non-empty list of state dicts sharing the same keys and shapes.
        args: namespace providing ``device``, where the float32 accumulator
            for each entry is placed.

    Returns:
        A deep copy of ``w[0]`` whose tensors hold the element-wise mean
        over all entries of ``w``; the inputs are left untouched.
    """
    num_models = len(w)
    averaged = copy.deepcopy(w[0])
    for name, target in averaged.items():
        # Accumulate in float32 regardless of the stored dtype.
        running = torch.zeros_like(w[0][name], dtype = torch.float32).to(args.device)
        for client_state in w:
            running += client_state[name]
        # copy_ writes in place, casting back to the entry's own dtype.
        target.copy_(torch.div(running, num_models))
    return averaged

## 训练更新

In [None]:
#测试代码
from torch.utils.data import DataLoader, Dataset
from random import randint

def test_img(net_g, datatest, args):
    net_g.eval()
    
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    with torch.no_grad():
        for idx, (data, target) in enumerate(data_loader):
            if args.gpu != -1:
                data, target = data.to(args.device), target.to(args.device)
            log_probs = net_g(data)
            # sum up batch loss
            test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
            
            # get the index of the max log-probability
            y_pred = log_probs.data.max(1, keepdim=True)[1]
            correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

        test_loss /= len(data_loader.dataset)
        accuracy = 100.00 * correct / len(data_loader.dataset)
    
        if args.verbose:
            print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
                test_loss, correct, len(data_loader.dataset), accuracy))
        return accuracy, test_loss
    
    
#更新
class DatasetSplit(Dataset):
    """View of `dataset` restricted to the positions listed in `idxs`."""

    def __init__(self, dataset, idxs):
        # Materialise idxs as a list so it can be addressed by position.
        self.dataset, self.idxs = dataset, list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        # Translate the local position into an index of the wrapped dataset.
        source_index = self.idxs[item]
        image, label = self.dataset[source_index]
        return image, label

class ModelUpdate(object):
    """Client-side local training on one client's share of the data.

    Args:
        args: namespace providing ``local_bs``, ``lr``, ``momentum``,
            ``sys_homo``, ``local_ep``, ``min_le``, ``max_le``, ``device``,
            ``fed`` and ``mu``.
        dataset: dataset the client's indices refer to.
        idxs: iterable of indices of ``dataset`` owned by this client.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        # Shuffled mini-batches over this client's samples only.
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, local_net, net):
        """Run local SGD on `net` and return its weights and mean loss.

        Args:
            local_net: reference copy of the global model; only read, to
                build the FedProx proximal term.
            net: model that is trained in place.

        Returns:
            (state_dict, loss): ``net.state_dict()`` after training and the
            mean over epochs of the per-epoch mean batch loss.
        """
        net.train()

        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        epoch_loss = []

        # NOTE(review): the flag's help text reads "no system heterogeneity",
        # yet the randomised epoch count is used when the flag IS set.
        # Polarity is left unchanged here - confirm the intended behaviour.
        if not self.args.sys_homo:
            local_ep = self.args.local_ep
        else:
            local_ep = randint(self.args.min_le, self.args.max_le)

        for epoch in range(local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                images, labels = images.to(self.args.device), labels.to(self.args.device)

                # Gradients must not accumulate across mini-batches.
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)

                # FedProx: add (mu / 2) * ||w_t - w||^2, skipped in the first
                # epoch as before. The term is built from the live parameters
                # of `net` (not `.data`), so it takes part in backpropagation
                # and actually pulls the local weights towards the reference
                # model; the reference weights are detached and act as
                # constants.
                if self.args.fed == 'fedprox' and epoch > 0:
                    for w, w_t in zip(local_net.parameters(), net.parameters()):
                        loss = loss + self.args.mu / 2. * (w_t - w.detach()).pow(2).sum()

                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            # Mean batch loss of this epoch.
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        # Mean over all local epochs.
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

# 训练过程

In [None]:
# Run one complete federated-training experiment per simulated drop rate.
# drop_rate acts as a percentage: a client update is discarded when a random
# integer in [0, 99] falls inside range(drop_rate).
for drop_rate in[0]:
    # Build the global network for the selected model / dataset pair.
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNN_v2(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNN_v1(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'fashion-mnist':
        net_glob = CNN_v1(args=args).to(args.device)
    elif args.model == 'mlp':
        # Input width of the MLP is the product of all image dimensions.
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')

    print(net_glob)
    import pandas as pd
    import random
    net_glob.train()   # training mode

    w_glob = net_glob.state_dict()

    # Disabled initialization stage: pre-train the global model on the
    # server-held data dg before federated training starts.
    # initialization_stage = ModelUpdate(args=args, dataset=dataset, idxs=set(dg_idx))
    # w_glob, _ = initialization_stage.train(local_net = copy.deepcopy(net_glob).to(args.device), net = copy.deepcopy(net_glob).to(args.device))
    # Initialise the model parameters explicitly instead.

    # Per-module initialiser: Xavier-normal weights and zero bias for Linear,
    # Kaiming-normal weights for Conv2d, unit weight / zero bias for BatchNorm2d.
    def weight_init(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        # Convolution layers get their own initialisation scheme.
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # Batch-normalisation layers.
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    # Apply weight_init to every sub-module, then re-read the global weights.
    net_glob.apply(weight_init)
    w_glob = net_glob.state_dict()

    if args.all_clients:
        print("Aggregation over all clients")
        # Every client slot starts out holding the global weights.
        w_locals = [w_glob for i in range(args.num_users)]
    serve_loss=[]
    serve_acc=[]
    # distribute globally shared data (uniform distribution)
    # NOTE(review): share_idx is not referenced again in this cell.
    share_idx = uniform_distribute(dg, args)
    for iter in trange(args.rounds):  # one iteration per communication round

        if not args.all_clients:
            w_locals = []

        m = max(int(args.frac * args.num_users), 1) 

        # Select m distinct clients for this round.
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:

            # Local update on the data owned by this client.
            local = ModelUpdate(args=args, dataset=dataset_train, idxs=set(list(dict_users[idx])))

            # Train a fresh copy of the global model; returns the updated weights and the local loss.
            w, loss = local.train(local_net = copy.deepcopy(net_glob).to(args.device), net = copy.deepcopy(net_glob).to(args.device))

            if args.all_clients:
                rand = random.randint(0,99)
                drop_communication = range(drop_rate)
                if rand not in drop_communication:
                    w_locals[idx] = copy.deepcopy(w)
            else:
                rand = random.randint(0,99)
                drop_communication = range(drop_rate)
                if rand not in drop_communication:
                    w_locals.append(copy.deepcopy(w))   # simulated packet loss: dropped updates are not collected


        # update global weights
        # NOTE(review): with a non-zero drop_rate every update of a round may
        # be dropped, leaving w_locals empty here - confirm FedAvg copes with that.
        print('len(w_locals)',len(w_locals))
        w_glob = FedAvg(w_locals, args)  # server-side aggregation into the new global model

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # Evaluate the new global model on the test set.
        acc_test, loss_test = test_img(net_glob, dataset_test, args)
        serve_loss.append(loss_test)
        serve_acc.append(acc_test)



        if not args.debug:
            print(f"Round: {iter}")
            print(f"Test accuracy: {acc_test}")
            print(f"Test loss: {loss_test}")
    # Persist the per-round test loss and accuracy as header-less CSV files.
    test_loss=pd.DataFrame(data=serve_loss)
    test_loss_filename='FedAvg_Loss_'+str(args.dataset)+'round:'+str(args.rounds)+str(args.sampling)+'.csv'
    test_loss.to_csv(test_loss_filename,header=None) 

    test_acc=pd.DataFrame(data=serve_acc)
    test_acc_filename='FedAvg_Acc_'+str(args.dataset)+'round:'+str(args.rounds)+str(args.sampling)+'.csv'
    test_acc.to_csv(test_acc_filename,header=None) 

In [None]:
#调参：for MNIST, B = 100, E =1, η = 0.01, decay rate = 0.995;
###   for CIFAR-10, B = 100, E = 1, η = 0.1, decay rate = 0.992