In [1]:
import torch
from torchvision import datasets, transforms
import copy
import numpy as np
import random
from tqdm import trange

# 参数设置

In [2]:
import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    
    # federated arguments
    parser.add_argument('--fed', type=str, default='fedavg', help="federated optimization algorithm")
    parser.add_argument('--mu', type=float, default=1e-2, help='hyper parameter for fedprox')
    parser.add_argument('--rounds', type=int, default=100, help="total number of communication rounds")   #default=200
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=1, help="fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=10, help="number of local epochs: E")  #default=10
    parser.add_argument('--min_le', type=int, default=5, help="minimum number of local epoch")
    parser.add_argument('--max_le', type=int, default=15, help="maximum number of minimum local epoch")
    parser.add_argument('--local_bs', type=int, default=20, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=20, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.001, help="client learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--classwise', type=int, default=100, help="number of images for each class (global dataset)")
    
    #classwise  1000  800 500 300
    parser.add_argument('--alpha', type=float, default=100, help="test dataset")
    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--model', type=str, default='cnn', help='model name')
    parser.add_argument('--sampling', type=str, default='iid', help="sampling method")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=3, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    parser.add_argument('--sys_homo', action='store_true', help='no system heterogeneity')
    parser.add_argument('--tsboard', action='store_true', help='tensorboard')
    parser.add_argument('--debug', action='store_true', help='debug')
    
    args = parser.parse_args(args=[])
    
    return args

# 网络结构

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

class CNN_v1(nn.Module):
    def __init__(self, args):
        super(CNN_v1, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class CNN_v2(nn.Module):
    def __init__(self, args):
        super(CNN_v2, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Alexnet(nn.Module):

    def __init__(self, args):
        super(Alexnet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(args.num_channels, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, args.num_classes),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


# 数据处理

### 划分共享数据

In [4]:
def uniform_distribute(dataset, args):  #各类取100数量的数据组评分数据集
    globally_shared_data_idx = []
    
    idxs = np.arange(len(dataset))
    
    if args.dataset == "mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "fashion-mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "cifar":
        labels = np.array(dataset.targets)
    else:
        exit('Error: unrecognized dataset')
    
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]    #按照标签排序

    idxs = idxs_labels[0]
    labels = idxs_labels[1]
    
    for i in range(args.num_classes):
        specific_class = np.extract(labels == i, idxs)
        globally_shared_data = np.random.choice(specific_class, int(args.alpha), replace=False)  #replace表示不可取相同的元素
        
        globally_shared_data_idx = globally_shared_data_idx + list(globally_shared_data)
    
    return globally_shared_data_idx   #返回共享数据的索引

def train_dg_split(dataset, args): 
    dg_idx = []
    train_idx = []
    idxs = np.arange(len(dataset))

    if args.dataset == "mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "fashion-mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "cifar":
        labels = np.array(dataset.targets)
    else:
        exit('Error: unrecognized dataset')
    
    idxs_labels = np.vstack((idxs, labels))  #沿着竖直方向进行堆叠
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
    
    idxs = idxs_labels[0]
    labels = idxs_labels[1]
    
    for i in range(args.num_classes):
        specific_class = np.extract(labels == i, idxs)  #从数组中提取符合条件的元素
        
        dg = np.random.choice(specific_class, args.classwise, replace=False)  #从每一类中选取1000（classwise）个数据，且不重复
        
        train_tmp = set(specific_class)-set(dg)
        
        dg_idx = dg_idx + list(dg)
        
        train_idx = train_idx + list(train_tmp)
    
    return dg_idx, train_idx      #dg：从每一类中选取1000（classwise）个数据组成服务器专有数据 train_tmp：每一类中dg挑选剩下的组成


### 划分IID与Non-IID数据

In [5]:
def iid(dataset, num_users):
    """
    Sample I.I.D. client data from dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    
    for i in range(num_users):
        
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))   #随机获得索引，且不重复
        all_idxs = list(set(all_idxs) - dict_users[i])  #set（）创建一个无序不重复元素集，可进行关系测试，删除重复数据，还可以计算交集、差集、并集等
    
    return dict_users

def noniid(dataset, args):
    """
    Sample non-I.I.D client data from dataset
    -> 不同的客户端拥有不同的数据量
    :param dataset:
    :param num_users:
    :return:
    """
    num_dataset = len(dataset)
    idx = np.arange(num_dataset)  #注意此处对dataset_train索引进行重新排序，而不是dataset
    dict_users = {i: list() for i in range(args.num_users)}
    
    min_num = 100
    max_num = 600

    random_num_size = np.random.randint(min_num, max_num+1, size=args.num_users)
    print(f"Total number of datasets owned by clients : {sum(random_num_size)}")

    # total dataset should be larger or equal to sum of splitted dataset.
    assert num_dataset >= sum(random_num_size)  #assert（断言）用于判断一个表达式，在表达式条件为 false 的时候触发异常

    # divide and assign
    for i, rand_num in enumerate(random_num_size):

        rand_set = set(np.random.choice(idx, rand_num, replace=False))
        idx = list(set(idx) - rand_set)
        dict_users[i] = rand_set

    return dict_users
def noniid_one(dataset, args):
    num_items = int(len(dataset)/args.num_users)-100
    print("num_items:",num_items)
    dict_users = {}
    idxs = np.arange(len(dataset))
    
    if args.dataset == "mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "fashion-mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "cifar":
        labels = np.array(dataset.targets)
    else:
        exit('Error: unrecognized dataset')
    
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]    #按照标签排序

    idxs = idxs_labels[0]
    labels = idxs_labels[1]
    
    for i in range(args.num_classes):
        specific_class = np.extract(labels == i, idxs)
        print("len_specific_class",len(specific_class))
        for j in range(10):
            dict_users[i*10+j] =np.random.choice(specific_class,  num_items, replace=False)
            specific_class=list(set(specific_class)-set(dict_users[i*10+j]))
    return dict_users

def noniid_two(dataset, args):    ##实现了每个客户端两类数据，但是未遍历整个数据集
    num_items = int((int(len(dataset)/args.num_users)-100)/2)
    print("num_items",num_items)
    dict_users = {}
    idxs = np.arange(len(dataset))
    
    if args.dataset == "mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "fashion-mnist":
        labels = dataset.targets.numpy()
    elif args.dataset == "cifar":
        labels = np.array(dataset.targets)
    else:
        exit('Error: unrecognized dataset')
    
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]    #按照标签排序

    idxs = idxs_labels[0]
    labels = idxs_labels[1]
    for i in range(args.num_classes):
        specific_class_one = np.extract(labels == i, idxs)
        L=list(range(args.num_classes))
        print(L,i)
        specific_class_two_idx = random.choice(L)
        while specific_class_two_idx ==i:
            specific_class_two_idx = random.choice(L)
        print(specific_class_two_idx)
        specific_class_two = np.extract(labels == specific_class_two_idx, idxs)
        print("len_specific_class_one",len(specific_class_one))
        for j in range(10):
            dict_users[i*10+j] =np.random.choice(specific_class_one,num_items, replace=False)
            specific_class_one=list(set(specific_class_one)-set(dict_users[i*10+j]))
            dict_users[i*10+j] =list(set(dict_users[i*10+j])|set(np.random.choice(specific_class_two,  num_items, replace=False)))
            specific_class_two=list(set(specific_class_two)-set(dict_users[i*10+j]))
            print(len(dict_users[i*10+j]))

    return dict_users

In [6]:
# parse args
args = args_parser()
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
print( "GPU:",torch.cuda.is_available())
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)     
torch.cuda.manual_seed_all(args.seed)

GPU: True


In [7]:
# load dataset and split users
if args.dataset == 'mnist':
    trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])  #ToTensor()将图片变成Tensor类型数据
    dataset = datasets.MNIST('../autodl-tmp/data/mnist/', train=True, download=True, transform=trans_mnist)
    dataset_test = datasets.MNIST('../autodl-tmp/data/mnist/', train=False, download=True, transform=trans_mnist)

    dg = copy.deepcopy(dataset)
    dataset_train = copy.deepcopy(dataset)  #dg是服务器拥有的共享数据，dataset_train是准备分配给客户端的数据
    test_dataset= copy.deepcopy(dataset) #评分数据集
    
    share_idx = uniform_distribute(dg, args)   #评分数据集
    dg_idx, dataset_train_idx = train_dg_split(dataset, args)      #返回列表形式的索引

    dg.data, dataset_train.data ,test_dataset.data= dataset.data[dg_idx], dataset.data[dataset_train_idx], dataset.data[share_idx]
    dg.targets, dataset_train.targets,test_dataset.targets = dataset.targets[dg_idx], dataset.targets[dataset_train_idx], dataset.targets[share_idx]  #至此获得服务器专有数据  以及待分发给客户端的数据

    # sample users
    if args.sampling == 'iid':     #给客户端分配数据
        dict_users = iid(dataset_train, args.num_users)
    elif args.sampling == 'noniid':
        dict_users = noniid(dataset_train, args)
    elif args.sampling == 'noniid_one':
        dict_users = noniid_one(dataset_train, args)
    elif args.sampling == 'noniid_two':
        dict_users = noniid_two(dataset_train, args)
    else:
        exit('Error: unrecognized sampling')

elif args.dataset == 'fashion-mnist':
    trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])  #ToTensor()将图片变成Tensor类型数据
    dataset = datasets.FashionMNIST('../autodl-tmp/data/fashion-mnist/', train=True, download=True, transform=trans_mnist)
    dataset_test = datasets.FashionMNIST('../autodl-tmp/data/fashion-mnist/', train=False, download=True, transform=trans_mnist)

    dg = copy.deepcopy(dataset)
    dataset_train = copy.deepcopy(dataset)  #dg是服务器拥有的共享数据，dataset_train是准备分配给客户端的数据
    test_dataset= copy.deepcopy(dataset) #评分数据集
    
    share_idx = uniform_distribute(dg, args)   #评分数据集
    dg_idx, dataset_train_idx = train_dg_split(dataset, args)      #返回列表形式的索引

    dg.data, dataset_train.data ,test_dataset.data= dataset.data[dg_idx], dataset.data[dataset_train_idx], dataset.data[share_idx]
    dg.targets, dataset_train.targets,test_dataset.targets = dataset.targets[dg_idx], dataset.targets[dataset_train_idx], dataset.targets[share_idx]  #至此获得服务器专有数据  以及待分发给客户端的数据

    # sample users
    if args.sampling == 'iid':     #给客户端分配数据
        dict_users = iid(dataset_train, args.num_users)
    elif args.sampling == 'noniid':
        dict_users = noniid(dataset_train, args)
    elif args.sampling == 'noniid_one':
        dict_users = noniid_one(dataset_train, args)
    elif args.sampling == 'noniid_two':
        dict_users = noniid_two(dataset_train, args)
    else:
        exit('Error: unrecognized sampling')   
        
elif args.dataset == 'cifar':
    trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = datasets.CIFAR10('../autodl-tmp/data/cifar-10', train=True, download=True, transform=trans_cifar)
    dataset_test = datasets.CIFAR10('../autodl-tmp/data/cifar-10', train=False, download=True, transform=trans_cifar)

    dg = copy.deepcopy(dataset)
    dataset_train = copy.deepcopy(dataset)
    test_dataset= copy.deepcopy(dataset) #评分数据集
    
    share_idx = uniform_distribute(dg, args)   #评分数据集

    dg_idx, dataset_train_idx = train_dg_split(dataset, args)

    dg.targets.clear()
    dataset_train.targets.clear()
    test_dataset.targets.clear()


    dg.data, dataset_train.data,test_dataset.data = dataset.data[dg_idx], dataset.data[dataset_train_idx],dataset.data[share_idx]

    for i in list(dg_idx):
        dg.targets.append(dataset[i][1])
    for i in list(dataset_train_idx):
        dataset_train.targets.append(dataset[i][1])
    for i in list(share_idx):
        test_dataset.targets.append(dataset[i][1])

    # sample users    #获取不同数据分布下的数据索引
    if args.sampling == 'iid':
        dict_users = iid(dataset_train, args.num_users)
    elif args.sampling == 'noniid':
        dict_users = noniid(dataset_train, args)
    elif args.sampling == 'noniid_one':
        dict_users = noniid_one(dataset_train, args)
    elif args.sampling == 'noniid_two':
        dict_users = noniid_two(dataset_train, args)
    else:
        exit('Error: unrecognized sampling')

else:
    exit('Error: unrecognized dataset')

img_size = dataset_train[0][0].shape

print(dict_users[0])
# print(dataset_train.targets[list(dict_users[9])])

{10245, 45061, 24583, 34826, 18449, 49169, 40979, 45079, 6171, 6181, 34859, 57393, 47156, 6197, 4150, 30775, 32824, 53307, 45117, 49214, 26688, 4167, 26699, 57420, 30795, 4175, 30800, 55380, 41045, 49239, 22616, 34905, 34917, 8293, 20584, 30825, 16490, 12402, 47218, 4212, 16501, 12403, 14458, 26750, 43139, 36999, 34954, 57483, 34958, 34961, 43154, 24724, 55449, 37017, 10395, 6299, 57499, 53410, 24739, 43172, 10408, 55477, 6333, 26826, 49354, 206, 30927, 14548, 8412, 24801, 49379, 2276, 16611, 47333, 22758, 8424, 45287, 10477, 26862, 49398, 2295, 28918, 30969, 24840, 26894, 55570, 49426, 37143, 43294, 18719, 33056, 55585, 57636, 20783, 2351, 28975, 6451, 22836, 308, 14644, 22846, 35135, 51521, 47426, 322, 4422, 6471, 22859, 26957, 16720, 49495, 22879, 57700, 33126, 6505, 10603, 35180, 53612, 37233, 20849, 29043, 39284, 43379, 22908, 380, 29054, 14718, 24957, 14722, 8582, 4489, 14730, 14731, 39310, 33178, 51628, 4525, 20912, 27061, 35255, 39352, 47549, 20927, 51669, 41434, 37341, 20959, 

In [8]:
def FedAvg(w, args):
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        tmp = torch.zeros_like(w[0][k], dtype = torch.float32).to(args.device)
        for i in range(len(w)):
            tmp += w[i][k]
        tmp = torch.div(tmp, len(w))
        w_avg[k].copy_(tmp)
    return w_avg
def FedPSO(w_locals,dic_client_result,k):   #top-K在这里修改就可，然后取模型的平均值即可
    key_list=list(dic_client_result.keys())
    val_list=list(dic_client_result.values())
    t = copy.deepcopy(val_list)
    
    # 求k个最大的数值及其索引
    max_number = []
    max_index = []
    for _ in range(k):
        number = max(t)
        index = t.index(number)
        t[index] = 0
        max_number.append(number)
        max_index.append(index)
    t = []
    
    maxid_list=[]
    for _ in max_index:
        maxid_list.append(key_list[_])   #获得k个最大值的索引
        
    w=[]   #存储k个最好的模型
    for _ in maxid_list:
        w.append(w_locals[_])
        
    w_psg = copy.deepcopy(w[0])
    for k in w_psg.keys():
        tmp = torch.zeros_like(w[0][k], dtype = torch.float32).to(args.device)
        for i in range(len(w)):
            tmp += w[i][k]
        tmp = torch.div(tmp, len(w))
        w_psg[k].copy_(tmp)
    return w_psg

## 训练更新

In [9]:
#测试代码
from torch.utils.data import DataLoader, Dataset
from random import randint

def test_img(net_g, datatest, args):
    net_g.eval()
    
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    with torch.no_grad():
        for idx, (data, target) in enumerate(data_loader):
            if args.gpu != -1:
                data, target = data.to(args.device), target.to(args.device)
            log_probs = net_g(data)
            # sum up batch loss
            test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
            
            # get the index of the max log-probability
            y_pred = log_probs.data.max(1, keepdim=True)[1]
            correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

        test_loss /= len(data_loader.dataset)
        accuracy = 100.00 * correct / len(data_loader.dataset)
    
#         if args.verbose:
#             print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
#                 test_loss, correct, len(data_loader.dataset), accuracy))
        return accuracy, test_loss
    
    
#更新
class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

class ModelUpdate(object):   #客户端训练更新函数
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()   #定义损失函数
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)  #数据处理

    def train(self, local_net, net):   #重要代码
        
        net.train()
        
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)   #优化器定义
        epoch_loss = []

        if not self.args.sys_homo:   #给定本地训练时的Epoch
            local_ep = self.args.local_ep
        else:
            local_ep = randint(self.args.min_le, self.args.max_le)

        for iter in range(local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):   #迭代训练
                images, labels = images.to(self.args.device), labels.to(self.args.device)
#                 print(labels)
                net.zero_grad()    #在处理每一个batch时并不需要与其他batch的梯度混合起来累积计算，因此需要对每个batch调用一遍zero_grad（）将参数梯度置0
                log_probs = net(images)  #预测
                loss = self.loss_func(log_probs, labels)  #计算损失值

                if self.args.fed == 'fedprox':
                    if iter > 0: 
                        for w, w_t in zip(local_net.parameters(), net.parameters()):
                            loss += self.args.mu / 2. * torch.pow(torch.norm(w.data - w_t.data), 2)
                        
                loss.backward()  #反向传播，计算每个参数的梯度
                
                optimizer.step()  #利用梯度更新参数
#                 if not self.args.verbose and batch_idx % 10 == 0:   #显示
#                     print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                         iter, batch_idx * len(images), len(self.ldr_train.dataset),
#                                100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())    
            epoch_loss.append(sum(batch_loss)/len(batch_loss))  #获得这一轮训练这批数据的平均损失
        
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)    #获得在本地进行多轮训练的平均损失

# 训练过程

In [None]:
k=10
c2=2.0
for c1 in [1.0]:
    drop_rate=0
    #构建神经网咯
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNN_v2(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNN_v1(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'fashion-mnist':
        net_glob = CNN_v1(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')

    print(net_glob)
    from tqdm import tqdm
    import pandas as pd
    net_glob.train()   #训练模式

    acc=0.3
#     c1=1.0      #0.7
#     c2=2.0           #1.4
    # copy weights
#     w_glob = net_glob.state_dict()
    w_glob = net_glob.state_dict()

    # initialization stage of FedAvg
    # initialization_stage = ModelUpdate(args=args, dataset=dataset, idxs=set(dg_idx))    #用dg进行预训练,先指定数据集，再进行训练获得更新后的网络参数与平均损失值
    # w_glob, _ = initialization_stage.train(local_net = copy.deepcopy(net_glob).to(args.device), net = copy.deepcopy(net_glob).to(args.device))
    #初始化模型参数

    def weight_init(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        # 也可以判断是否为conv2d，使用相应的初始化方式 
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         # 是否为批归一化层
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    #将weight_init应用在子模块上
    net_glob.apply(weight_init)
    w_glob = net_glob.state_dict()

    # initialization stage of FedShare
#     initialization_stage = ModelUpdate(args=args, dataset=dataset, idxs=set(dg_idx))    #用dg进行预训练,先指定数据集，再进行训练获得更新后的网络参数与平均损失值
#     w_glob, _ = initialization_stage.train(local_net = copy.deepcopy(net_glob).to(args.device), net = copy.deepcopy(net_glob).to(args.device))

#     net_glob.load_state_dict(w_glob)  #预训练后模型参数更新

    # if args.all_clients:
    #     print("Aggregation over all clients")
    w_locals = [w_glob for i in range(args.num_users)]  #将全局模型下发给客户端
    w_best_locals = [w_glob for i in range(args.num_users)]     #局部最优模型

    server_evaluate_acc = []   #存储每一轮通信后全局模型的性能
    server_evaluate_loss = []

    global_best_model =copy.deepcopy(net_glob)   #初始化最优模型
    acc_test_pre, loss_test_pre = test_img(global_best_model, test_dataset, args) 
    global_best_score = acc_test_pre   #初始化最优模型评分(准确率)

    step_model=copy.deepcopy(net_glob) #作为第一轮训练前服务器发给客户端的模型
    stepwp_model=copy.deepcopy(net_glob)

    for iter in trange(args.rounds):  #对于每一轮通信
        client_result = []   #用之前归零，用于存储每一轮通信各个客户端的编号及其准确率值[(),(),,,,,()]   再进行dic = dict(client_result)

    #     if not args.all_clients:
    #         w_locals = []

        m = max(int(args.frac * args.num_users), 1) 

        idxs_users = np.random.choice(range(args.num_users), m, replace=False) #选取一定比例的客户端：m个  编号也是从这里开始，不会取 w_locals里所有的

        for idx in tqdm(idxs_users):   #对每个客户端进行训练   w_locals、step_wp、step_wg的更新  step_model为训练的模型  w_locals[idx]为其参数
            #step_model是在这轮通信训练前客户端收到服务器下发的模型，是客户端共有的

            #粒子更新
            step_w=step_model.state_dict()  ##   #客户端收到服务器下发的参数
            param_names= list(step_w)
            velocities = [ [] for i in range(len(param_names))]
            for i, name in enumerate(param_names):
                shape =step_w[name].shape
                velocities[i] = torch.rand(*shape)/5 - 0.1

    #         step_w_copy=copy.deepcopy(step_w)
            step_wp=w_best_locals[idx]    ##  局部最优
            step_wg=global_best_model.state_dict()  
            new_weight = [None] * len(step_w)
            local_rand, global_rand = torch.rand(1).to(args.device), torch.rand(1).to(args.device)
            for i,layer in enumerate(step_w):
                new_v=acc*velocities[i]
                new_v = new_v.to(args.device)
                new_v=new_v+c1*(step_wp[layer]-step_w[layer])*local_rand
                new_v=new_v+c2*(step_wg[layer]-step_w[layer])*global_rand 
                velocities[i] = new_v
                new_weight[i] = step_w[layer] + velocities[i]
                step_model.state_dict()[layer].copy_(new_weight[i])   
    #         step_model.load_state_dict(new_weight) 

            #本地迭代更新，输出更新后模型，loss值
            local = ModelUpdate(args=args, dataset=dataset_train, idxs=list(dict_users[idx]))  #使用客户端有的数据进行训练
            w, loss = local.train(local_net = copy.deepcopy(step_model).to(args.device), net = copy.deepcopy(step_model).to(args.device))  #获得更新后的模型参数以及损失值
            step_model.load_state_dict(w) 
            acc_test_step, loss_test_step = test_img(step_model, test_dataset, args)

            stepwp_model.load_state_dict(w_best_locals[idx]) 
            acc_test_step_wp, loss_test_step_wp = test_img(stepwp_model, test_dataset, args)
            if acc_test_step_wp<=acc_test_step:
                w_best_locals[idx]=copy.deepcopy(w)

            rand = random.randint(0,99)
            drop_communication = range(drop_rate)
            if rand not in drop_communication:
                client_result.append((idx,acc_test_step))     #设置丢包率

    #         if args.all_clients:
            w_locals[idx] = copy.deepcopy(w)   #每个客户端更新模型参数
    #         else:
    #             w_locals.append(copy.deepcopy(w))

        # update global weights

        dic_client_result = dict(client_result)
        w_glob = FedPSO(w_locals,dic_client_result,k)  #服务器聚合模型获得新的全局模型
        step_model.load_state_dict(w_glob) 
        #再进行一次训练，用元数据
        glob=ModelUpdate(args=args, dataset=dataset, idxs=set(dg_idx))
        w_glob, _ = glob.train(local_net = copy.deepcopy(step_model).to(args.device), net = copy.deepcopy(step_model).to(args.device))


        # copy weight to net_glob
        step_model.load_state_dict(w_glob)  #获取新的全局模型的参数,再下发给客户端

        acc_test, loss_test = test_img(step_model, test_dataset, args)  #开始测试，获得这一轮通信获得模型的准确率与损失值

        if  acc_test>=global_best_score:   #更新全局模型
            global_best_score=acc_test
            global_best_model.load_state_dict(w_glob)

        acc_test_T, loss_test_T = test_img(step_model, dataset_test, args)
        server_evaluate_acc.append(acc_test_T)
        server_evaluate_loss.append(loss_test_T)
        if not args.debug:
            print(f"Round: {iter}")
            print(f"Test accuracy: {acc_test_T}")
            print(f"Test loss: {loss_test_T}")
    test_loss=pd.DataFrame(data=server_evaluate_loss)
    test_loss_filename='FedPSO_Loss_'+str(args.dataset)+'round'+str(args.rounds)+str(args.sampling)+'_'+str(args.classwise)+'.csv'
    test_loss.to_csv(test_loss_filename,header=None) 

    test_acc=pd.DataFrame(data=server_evaluate_acc)
    test_acc_filename='FedPSO_Acc_'+str(args.dataset)+'round'+str(args.rounds)+str(args.sampling)+'_'+str(args.classwise)+'.csv'
    test_acc.to_csv(test_acc_filename,header=None) 

CNN_v1(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:02<03:40,  2.23s/it][A
  2%|▏         | 2/100 [00:04<03:44,  2.29s/it][A
  3%|▎         | 3/100 [00:06<03:41,  2.28s/it][A
  4%|▍         | 4/100 [00:09<03:40,  2.30s/it][A
  5%|▌         | 5/100 [00:11<03:39,  2.31s/it][A
  6%|▌         | 6/100 [00:13<03:37,  2.31s/it][A
  7%|▋         | 7/100 [00:16<03:34,  2.31s/it][A
  8%|▊         | 8/100 [00:18<03:32,  2.31s/it][A
  9%|▉         | 9/100 [00:20<03:30,  2.31s/it][A
 10%|█         | 10/100 [00:23<03:28,  2.32s/it][A
 11%|█         | 11/100 [00:25<03:24,  2.30s/it][A
 12%|█▏        | 12/100 [00:27<03:20,  2.28s/it][A
 13%|█▎        | 13/100 [00:29<03:19,  2.30s/it][A
 14%|█▍        | 14/100 [00:32<03:20,  2.33s/it][A
 15%|█▌        | 15/100 [00:34<03:18,  2.33s/it][A
 16%|█▌        | 16/100 [00:36<03:14,  2.32s/it][A
 17%|█▋        | 17/100 [00:39<03:12,  2.32s/it][A
 18%|█▊        | 18/100 [00:41<03:10,  2.3

Round: 0
Test accuracy: 84.43000030517578
Test loss: 0.5172773820757866



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:02<03:48,  2.31s/it][A
  2%|▏         | 2/100 [00:04<03:45,  2.30s/it][A
  3%|▎         | 3/100 [00:06<03:44,  2.31s/it][A
  4%|▍         | 4/100 [00:09<03:41,  2.31s/it][A
  5%|▌         | 5/100 [00:11<03:40,  2.32s/it][A
  6%|▌         | 6/100 [00:13<03:37,  2.32s/it][A
  7%|▋         | 7/100 [00:16<03:35,  2.32s/it][A
  8%|▊         | 8/100 [00:18<03:32,  2.32s/it][A
  9%|▉         | 9/100 [00:20<03:30,  2.32s/it][A
 10%|█         | 10/100 [00:23<03:27,  2.30s/it][A
 11%|█         | 11/100 [00:25<03:24,  2.29s/it][A
 12%|█▏        | 12/100 [00:27<03:22,  2.30s/it][A
 13%|█▎        | 13/100 [00:29<03:20,  2.30s/it][A
 14%|█▍        | 14/100 [00:32<03:18,  2.31s/it][A
 15%|█▌        | 15/100 [00:34<03:15,  2.31s/it][A
 16%|█▌        | 16/100 [00:37<03:16,  2.33s/it][A
 17%|█▋        | 17/100 [00:39<03:12,  2.32s/it][A
 18%|█▊        | 18/100 [00:41<03:09,  2.31s/it][A
 19%|█▉        | 19/100 [00:4