In [1]:
import time
import random
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if torch.cuda.is_available():

    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))


else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 3090


In [3]:
import pandas as pd
import numpy as np
import random

SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

df_train = pd.read_csv("./fine_tuning.csv")

train_data_domain = df_train.domain.values
train_data_label = df_train.label.values
train_data_label = train_data_label.tolist()
train_data_label = [0 if item == 2 else 1 for item in train_data_label]
train_data_label = np.array(train_data_label)

In [4]:
from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file="./bert_tokenizer/vocab.txt")

In [5]:
import warnings
warnings.filterwarnings('ignore')
input_ids_train = []
attention_masks_train = []

for sent in train_data_domain:

    encoded_dict = tokenizer.encode_plus(
        sent,                      # Sentence to encode.
        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
        max_length = 64,           # Pad & truncate all sentences.
        pad_to_max_length = True,
        return_attention_mask = True,   # Construct attn. masks.
        return_tensors = 'pt',     # Return pytorch tensors.
    )
    # Add the encoded sentence to the list.
    input_ids_train.append(encoded_dict['input_ids'])

    # And its attention mask (simply differentiates padding from non-padding).
    attention_masks_train.append(encoded_dict['attention_mask'])

# Convert the lists into tensors.
input_ids_train = torch.cat(input_ids_train, dim=0)
attention_masks_train = torch.cat(attention_masks_train, dim=0)
labels_train = torch.tensor(train_data_label)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [6]:
from torch.utils.data import TensorDataset, random_split

dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train) # 打包处理，所以数据第一维必须相等
# Calculate the number of samples to include in each set.
train_size = int(0.7 * len(dataset_train))
test_size = len(dataset_train) - train_size

train_dataset, test_dataset = random_split(dataset_train, [train_size, test_size])
print('{:>5,} training samples'.format(train_size))
print('{:>5,} test samples'.format(test_size))

111,998 training samples
48,000 test samples


构造MyModel

In [7]:
from torch import nn

EmbeddingPath = "./FedBert/FedTransformer.pt"
TransformerPath = "./FedBert/FedEmbedding.pt"
num_users = 10
frac = 0.5
local_epochs = 5
epochs = 30

In [8]:
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,DataCollatorForLanguageModeling,HfArgumentParser,Trainer,TrainingArguments,set_seed,
)
# 自己修改部分配置参数
config_kwargs = {
    "cache_dir": None,
    "revision": 'main',
    "use_auth_token": None,
    #      "hidden_size": 512,
    #     "num_attention_heads": 4,
    "hidden_dropout_prob": 0.2,
    "vocab_size": 1000 # 自己设置词汇大小
}
# 将模型的配置参数载入
config = AutoConfig.from_pretrained('./bert-base-uncased-model/', **config_kwargs)
print(config)
# 载入预训练模型
model = AutoModelForMaskedLM.from_config(
    config=config,
)
model.resize_token_embeddings(config_kwargs["vocab_size"])

embedding = model.bert.embeddings

class Bert_Embedding(nn.Module):
    def __init__(self):
        super(Bert_Embedding, self).__init__()
        self.embeddings = copy.deepcopy(embedding)

    def forward(self, input_ids, attn_mask):
        embedding_output = self.embeddings(input_ids, attn_mask)
        return embedding_output

embedding_model = Bert_Embedding()
embedding_model.load_state_dict(torch.load(EmbeddingPath))

encoder = model.bert.encoder
cls = model.cls

class Bert_Encoder(nn.Module):
    def __init__(self):
        super(Bert_Encoder, self).__init__()
        self.encoder = copy.deepcopy(encoder)
        self.cls = copy.deepcopy(cls)

    def forward(self, embedding_output):
        output_encoder = self.encoder(embedding_output).last_hidden_state
        return output_encoder
encoder_model = Bert_Encoder()
encoder_model.load_state_dict(torch.load(TransformerPath))

from transformers.models.bert.modeling_bert import BertPooler
class Pooler_Config:
    def __init__(self, entries: dict={}):
        for k, v in entries.items():
            if isinstance(v, dict):
                self.__dict__[k] = Pooler_Config(v)
            else:
                self.__dict__[k] = v

config_pooler = {"hidden_size": 768}
config_pooler = Pooler_Config(config_pooler)
pooler = BertPooler(config_pooler)
print(pooler)

class MyModel(nn.Module):
    def __init__(self, hidden_size=768, num_classes=2, freeze_bert=False):
        super(MyModel, self).__init__()
        self.embedding = Bert_Embedding()
        self.encoder = Bert_Encoder()
        self.pooler = copy.deepcopy(pooler)
        if freeze_bert:
            for p in self.embedding.parameters():
                p.requires_grad = False
            for p in self.encoder.parameters():
                p.requires_grad = False
        self.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(hidden_size, num_classes, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input_ids, attn_mask):
        embedding_outputs = self.embedding(input_ids, attn_mask)
        encoder_outputs = self.encoder(embedding_outputs)
        pooler_outputs = self.pooler(encoder_outputs)
        #它代表了一句话的embedding
        logits = self.fc(pooler_outputs)
        return logits

BertConfig {
  "_name_or_path": "./bert-base-uncased-model/",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.2,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.34.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 1000
}

BertPooler(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (activation): Tanh()
)


iid数据分割

In [9]:
def dataset_iid(dataset, num_users):

    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace = False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

# train_dataset[:][:][1]
def dirichlet_split_noniid(train_labels, num_users):
    '''
    按照参数为alpha的Dirichlet分布将样本索引集合划分为n_clients个子集
    '''
    alpha = 0.7
    n_classes = 2
    # (K, N) 类别标签分布矩阵X，记录每个类别划分到每个client去的比例
    label_distribution = np.random.dirichlet([alpha]*num_users, n_classes)
    # (K, ...) 记录K个类别对应的样本索引集合
    class_idcs = [np.argwhere(train_labels == y).flatten()
                  for y in range(n_classes)]

    # 记录N个client分别对应的样本索引集合
    client_idcs = [[] for _ in range(num_users)]
    for k_idcs, fracs in zip(class_idcs, label_distribution):
        # np.split按照比例fracs将类别为k的样本索引k_idcs划分为了N个子集
        # i表示第i个client，idcs表示其对应的样本索引集合idcs
        for i, idcs in enumerate(np.split(k_idcs,
                                          (np.cumsum(fracs)[:-1]*len(k_idcs)).
                                                  astype(int))):
            client_idcs[i] += [idcs]

    dict_users = [np.concatenate(idcs) for idcs in client_idcs]

    return dict_users

In [10]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

dict_user_train = dirichlet_split_noniid(train_dataset[:][:][:][2], num_users)
dict_user_test = dataset_iid(test_dataset, num_users)

# The DataLoader needs to know our batch size for training, so we specify it
# here. For fine-tuning BERT on a specific task, the authors recommend a batch
# size of 16 or 32.
batch_size = 32

In [11]:
net_glob = MyModel()
net_glob.encoder.cls = nn.Sequential()
print(net_glob)

MyModel(
  (embedding): Bert_Embedding(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(1000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
  (encoder): Bert_Encoder(
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), 

In [12]:
# 训练loss记录
loss_train_collect = {}
# 训练acc记录
acc_train_collect = {}
loss_test_collect = {}
# 测试acc记录
acc_test_collect = {}
# 训练TPR记录
TPR_train_collect = {}
# 测试TPR记录
TPR_test_collect = {}
# 训练FPR记录
FPR_train_collect = {}
# 测试FPR记录
FPR_test_collect = {}
# 训练测试F1-score记录
f1_train_collect = {}
f1_test_collect = {}
# 训练测试AUC记录
AUC_train_collect = {}
AUC_test_collect = {}
# 训练测试ROC曲线记录
ROC_train_collect = {}
ROC_test_collect = {}

local_test = {}
local_testing = {}

loss_collect = []
acc_collect = []
TPR_collect = []
FPR_collect = []
F1_collect = []
AUC_collect = []

count1 = 0
count2 = 0

In [13]:
def FedAvg(w):
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg

In [14]:
idx_collect = []
l_epoch_check = False
fed_check = False
# Initialization of net_model_server and net_server (server-side model)
net_model = [net_glob for i in range(num_users)]
net_server = copy.deepcopy(net_model[0]).to(device)

In [15]:
from torch.utils.data import DataLoader, Dataset

criterion = nn.BCELoss()

class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        domain, mask, label = self.dataset[self.idxs[item]]
        return domain, mask, label


In [16]:
import datetime
from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve

def flat_accuracy(preds, labels):
    return np.sum(preds == labels) / len(labels)

def tpr_calculate(preds, labels):
    return recall_score(labels, preds,zero_division=1)

def fpr_calculate(preds, labels):
    conf_matrix = confusion_matrix(labels, preds)
    #print(conf_matrix)
    fp = conf_matrix[0, 1]  # 0 表示负类别，1 表示正类别
    tn = conf_matrix[0, 0]
    fpr = fp / (fp + tn)
    return fpr

def f1_score_calculate(preds, labels):
    return f1_score(labels, preds, zero_division=1)

def AUC_calculate(preds, labels):
    return roc_auc_score(labels, preds)

def roc_curve_calculate(preds, labels):
    return roc_curve(labels, preds)

def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))

    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [17]:
class Client(object):
    def __init__(self, device, idx, lr, local_epochs, batch_size, dataset_train = None, dataset_test = None, idxs = None, idxs_test = None):
        self.device = device
        self.lr = lr
        self.idx = idx
        self.local_ep = local_epochs
        self.ldr_train = DataLoader(DatasetSplit(dataset_train, idxs), batch_size = batch_size, shuffle = True)
        self.ldr_test = DataLoader(DatasetSplit(dataset_test, idxs_test), batch_size = batch_size, shuffle= True)
        self.ala = ALA(idx, idxs, criterion, dataset_train, batch_size, 20, 0, lr*4, self.device, 0.1, 10)

    def local_initialization(self, net):
        print("local_initialization!")
        local_model = net_local[self.idx].to(self.device)
        self.ala.adaptive_local_aggregation(net, local_model)

    def train(self, net):
        net.train()
        optimizer_client = torch.optim.Adam(net.parameters(), lr = self.lr)

        TPR_train_collect[self.idx] = []
        FPR_train_collect[self.idx] = []
        f1_train_collect[self.idx] = []
        AUC_train_collect[self.idx] = []
        loss_train_collect[self.idx] = []
        acc_train_collect[self.idx] = []

        epoch_loss = []
        epoch_accuracy = []
        for iter in range(self.local_ep):
            tmp_t0 = time.time()
            batch_loss_train = []
            batch_acc_train = []
            batch_tpr_train = []
            batch_fpr_train = []
            batch_f1_train = []
            batch_auc_train = []
            for batch_idx, (ids, attn_mask, b_labels) in enumerate(self.ldr_train):
                ids, attn_mask, b_labels = ids.to(self.device), attn_mask.to(self.device), b_labels.to(self.device)
                optimizer_client.zero_grad()
                b_labels = b_labels.unsqueeze(1)
                b_labels = b_labels.repeat(1,2)
                for i in range(len(b_labels)):
                    b_labels[i][1] = 1-b_labels[i][0]
                logits = net(ids, attn_mask)
                BCEloss = criterion(logits, b_labels.float())
                BCEloss.backward()
                optimizer_client.step()
                batch_loss_train.append(BCEloss.item())
                logits = logits.detach().cpu().numpy()
                label_ids = b_labels.to('cpu').numpy()
                logits = np.argmax(logits, axis=1).flatten()
                label_ids = np.argmax(label_ids,axis=1).flatten()
                accuracy = flat_accuracy(logits, label_ids)
                tpr = tpr_calculate(logits, label_ids)
                if len(set(label_ids)) == 1:
                    pass
                else:
                    fpr = fpr_calculate(logits, label_ids)
                    batch_fpr_train.append(fpr)
                f1 = f1_score_calculate(logits, label_ids)
                if len(set(label_ids)) == 1:
                    pass
                else:
                    auc = AUC_calculate(logits, label_ids)
                    batch_auc_train.append(auc)
                batch_acc_train.append(accuracy)
                batch_tpr_train.append(tpr)
                batch_f1_train.append(f1)
            elapsed = format_time(time.time()-tmp_t0)
            epoch_avg_loss = sum(batch_loss_train)/len(batch_loss_train)
            epoch_avg_acc = sum(batch_acc_train)/len(batch_acc_train)
            epoch_avg_tpr = sum(batch_tpr_train)/len(batch_tpr_train)
            epoch_avg_fpr = sum(batch_fpr_train)/len(batch_fpr_train)
            epoch_avg_f1 = sum(batch_f1_train)/len(batch_f1_train)
            epoch_avg_auc = sum(batch_auc_train)/len(batch_auc_train)
            epoch_loss.append(sum(batch_loss_train)/len(batch_loss_train))
            epoch_accuracy.append(sum(batch_acc_train)/len(batch_acc_train))
            loss_train_collect[self.idx].append(epoch_avg_loss)
            acc_train_collect[self.idx].append(epoch_avg_acc)
            TPR_train_collect[self.idx].append(epoch_avg_tpr)
            FPR_train_collect[self.idx].append(epoch_avg_fpr)
            f1_train_collect[self.idx].append(epoch_avg_f1)
            AUC_train_collect[self.idx].append(epoch_avg_auc)
            loss_collect.append(epoch_avg_loss)
            acc_collect.append(epoch_avg_acc)
            TPR_collect.append(epoch_avg_tpr)
            FPR_collect.append(epoch_avg_fpr)
            F1_collect.append(epoch_avg_f1)
            AUC_collect.append(epoch_avg_auc)

            print('Client{} Local Train => Local Epoch: {} \tLoss: {:.10f} \tAcc: {:.10f} \tTPR:{:.10f} \tFPR:{:.10f} \tF1:{:.10f} \t AUC:{:.10f}\tTrain cost: {:}'.format(self.idx, iter, epoch_avg_loss, \
                                                                                                                                                                           epoch_avg_acc, epoch_avg_tpr, epoch_avg_fpr, epoch_avg_f1, epoch_avg_auc, elapsed))
        net_glob.load_state_dict(net.state_dict())
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss), sum(epoch_accuracy) / len(epoch_accuracy)

    def evaluate(self, net, ell):
        net.eval()

        with torch.no_grad():
            tmp_t0 = time.time()
            len_batch = len(self.ldr_test)

            batch_acc_test = []
            batch_loss_test = []
            batch_tpr_test = []
            batch_fpr_test = []
            batch_f1_test = []
            batch_auc_test = []
            for batch_idx, (ids, attn_mask, b_labels) in enumerate(self.ldr_test):
                ids, attn_mask, b_labels = ids.to(self.device), attn_mask.to(self.device), b_labels.to(self.device)
                b_labels = b_labels.unsqueeze(1)
                b_labels = b_labels.repeat(1,2)
                for i in range(len(b_labels)):
                    b_labels[i][1] = 1-b_labels[i][0]
                logits = net(ids, attn_mask)
                BCEloss = criterion(logits, b_labels.float())
                batch_loss_test.append(BCEloss.item())
                logits = logits.detach().cpu().numpy()
                label_ids = b_labels.to('cpu').numpy()
                logits = np.argmax(logits, axis=1).flatten()
                label_ids = np.argmax(label_ids,axis=1).flatten()
                accuracy = flat_accuracy(logits, label_ids)
                tpr = tpr_calculate(logits, label_ids)
                if len(set(label_ids)) == 1:
                    pass
                else:
                    fpr = fpr_calculate(logits, label_ids)
                    batch_fpr_test.append(fpr)
                f1 = f1_score_calculate(logits, label_ids)
                if len(set(label_ids)) == 1:
                    pass
                else:
                    auc = AUC_calculate(logits, label_ids)
                    batch_auc_test.append(auc)
                batch_acc_test.append(accuracy)
                batch_tpr_test.append(tpr)
                batch_f1_test.append(f1)
            elapsed = format_time(time.time()-tmp_t0)
            test_avg_loss = sum(batch_loss_test) / len(batch_loss_test)
            test_avg_acc = sum(batch_acc_test) / len(batch_acc_test)
            test_avg_tpr = sum(batch_tpr_test)/len(batch_tpr_test)
            test_avg_fpr = sum(batch_fpr_test) / len(batch_fpr_test)
            test_avg_f1 = sum(batch_f1_test)/len(batch_f1_test)
            test_avg_auc = sum(batch_auc_test)/len(batch_auc_test)
            local_test["loss"].append(test_avg_loss)
            local_test["acc"].append(test_avg_acc)
            local_test["tpr"].append(test_avg_tpr)
            local_test["fpr"].append(test_avg_fpr)
            local_test["f1"].append(test_avg_f1)
            local_test["auc"].append(test_avg_auc)
            print('Client{} Test =>                 \tLoss: {:.10f} \tAcc: {:.10f} \tTPR:{:.10f} \tFPR:{:.10f} \tF1:{:.10f} \tAUC:{:.10f} \ttest cost: {:}'.format(self.idx, test_avg_loss, test_avg_acc, test_avg_tpr, test_avg_fpr, test_avg_f1, test_avg_auc, elapsed))

        return test_avg_loss, test_avg_acc, test_avg_tpr, test_avg_fpr, test_avg_f1, test_avg_auc

In [18]:
class ALA:
    def __init__(self,
                 cid: int,
                 idxs,
                 loss,
                 train_data: TensorDataset,
                 batch_size: int,
                 rand_percent: int,
                 layer_idx: int = 0,
                 eta: float = 1.0,
                 device: str = 'cpu',
                 threshold: float = 0.1,
                 num_pre_loss: int = 10) -> None:
        """
        Initialize ALA module

        Args:
            cid: Client ID.
            loss: The loss function.
            train_data: The reference of the local training data.
            batch_size: Weight learning batch size.
            rand_percent: The percent of the local training data to sample.
            layer_idx: Control the weight range. By default, all the layers are selected. Default: 0
            eta: Weight learning rate. Default: 1.0
            device: Using cuda or cpu. Default: 'cpu'
            threshold: Train the weight until the standard deviation of the recorded losses is less than a given threshold. Default: 0.1
            num_pre_loss: The number of the recorded losses to be considered to calculate the standard deviation. Default: 10

        Returns:
            None.
        """

        self.cid = cid
        self.idxs = idxs
        self.loss = loss
        self.train_data = train_data
        self.batch_size = batch_size
        self.rand_percent = rand_percent
        self.layer_idx = layer_idx
        self.eta = eta
        self.threshold = threshold
        self.num_pre_loss = num_pre_loss
        self.device = device

        self.weights = None # Learnable local aggregation weights.
        self.start_phase = True


    def adaptive_local_aggregation(self,
                                   global_model: nn.Module,
                                   local_model: nn.Module) -> None:
        """
        Generates the Dataloader for the randomly sampled local training data and
        preserves the lower layers of the update.

        Args:
            global_model: The received global/aggregated model.
            local_model: The trained local model.

        Returns:
            None.
        """

        # randomly sample partial local training data
        rand_ratio = self.rand_percent / 100
        rand_loader = DataLoader(DatasetSplit(self.train_data, self.idxs), self.batch_size, drop_last=True, shuffle=True)
        rand_num = int(rand_ratio*len(rand_loader))


        # obtain the references of the parameters
        params_g = list(global_model.parameters())
        params = list(local_model.parameters())

        # deactivate ALA at the 1st communication iteration
        if torch.sum(params_g[-1] - params[-1]) == 0:
            print("deactivate ALA")
            return

        # preserve all the updates in the lower layers
        for param, param_g in zip(params[:-self.layer_idx], params_g[:-self.layer_idx]):
            param.data = param_g.data.clone()


        # temp local model only for weight learning
        model_t = copy.deepcopy(local_model)
        params_t = list(model_t.parameters())
        params_t[-1].requires_grad_()

        # only consider higher layers
        params_p = params[-self.layer_idx:]
        params_gp = params_g[-self.layer_idx:]
        params_tp = params_t[-self.layer_idx:]


        # used to obtain the gradient of higher layers
        # no need to use optimizer.step(), so lr=0
        optimizer = torch.optim.Adam(params_tp, lr = lr)

        # initialize the weight to all ones in the beginning
        if self.weights == None:
            self.weights = [torch.ones_like(param.data).to(self.device) for param in params_p]

        # initialize the higher layers in the temp local model
        for param_t, param, param_g, weight in zip(params_tp, params_p, params_gp,
                                                   self.weights):
            param_t.data = param + (param_g - param) * weight

        # weight learning
        losses = []  # record losses
        losses_round = []
        cnt = 0  # weight training iteration counter
        while True:
            for batch_idx, (x, m , y) in enumerate(rand_loader):
                if batch_idx >= rand_num:
                    break
                x = x.to(self.device)
                m = m.to(self.device)
                y = y.to(self.device)
                y = y.unsqueeze(1)
                y = y.repeat(1,2)
                for i in range(len(y)):
                    y[i][1] = 1-y[i][0]
                optimizer.zero_grad()
                output = model_t(x, m)
                loss_value = self.loss(output, y.float()) # modify according to the local objective
                losses.append(loss_value.item())
                loss_value.backward()

                # update weight in this batch
                for param_t, param, param_g, weight in zip(params_tp, params_p,
                                                           params_gp, self.weights):
                    #print("param_t.grad:",param_t.requires_grad)
                    #print("param_g - param:",type(param_g - param))
                    weight.data = torch.clamp(weight - self.eta * (param_t.grad * (param_g - param)), 0, 1)

                # update temp local model in this batch
                for param_t, param, param_g, weight in zip(params_tp, params_p,
                                                           params_gp, self.weights):
                    param_t.data = param + (param_g - param) * weight

            losses_round.append(sum(losses) / len(losses))
            cnt += 1

            # only train one epoch in the subsequent iterations
            if not self.start_phase:
                break

            # train the weight until convergence
            if len(losses_round) > self.num_pre_loss or np.std(losses[-self.num_pre_loss:]) < self.threshold:
                print('Client:', self.cid, '\tStd:', np.std(losses[-self.num_pre_loss:]),
                      '\tALA epochs:', cnt)
                break

        self.start_phase = False

        # obtain initialized local model
        for param, param_t in zip(params_p, params_tp):
            param.data = param_t.data.clone()
        print("Client {}: Local Initial ALA epochs: {} Loss: {:.20f}".format(self.cid, cnt, losses_round[-1]))

In [19]:
print("Train and Test Begin!")
net_glob.train()
w_net_glob = net_glob.state_dict()
t0 = time.time()
net_local = [copy.deepcopy(net_glob) for i in range(num_users)]

lr = 2e-6

local_test["loss"] = []
local_test["acc"] = []
local_test["tpr"] = []
local_test["fpr"] = []
local_test["f1"] = []
local_test["auc"] = []

local_testing["loss"] = []
local_testing["acc"] = []
local_testing["tpr"] = []
local_testing["fpr"] = []
local_testing["f1"] = []
local_testing["auc"] = []

for iter in range(epochs):
    print("============== Round {}:  =============".format(iter))
    idx_collect = []
    m = max(int(frac * num_users) ,1)
    idxs_users = np.random.choice(range(num_users), m, replace = False)
    w_locals_client = []

    loss_list = []
    acc_list = []
    tpr_list = []
    fpr_list = []
    f1_list = []
    auc_list = []

    for idx in idxs_users:
        local = Client(device, idx, lr, local_epochs, batch_size, train_dataset, test_dataset, dict_user_train[idx], dict_user_test[idx])
        local.local_initialization(net = copy.deepcopy(net_glob).to(device))
        w_client, client_loss, client_acc = local.train(net = copy.deepcopy(net_local[idx]).to(device))
        w_locals_client.append(copy.deepcopy(w_client))
        loss, acc, tpr, fpr, f1, auc = local.evaluate(net = copy.deepcopy(net_glob).to(device), ell=iter)

        loss_list.append(loss)
        acc_list.append(acc)
        tpr_list.append(tpr)
        fpr_list.append(fpr)
        f1_list.append(f1)
        auc_list.append(auc)

    local_testing["loss"].append(sum(loss_list)/len(loss_list))
    local_testing["acc"].append(sum(acc_list)/len(acc_list))
    local_testing["tpr"].append(sum(tpr_list)/len(tpr_list))
    local_testing["fpr"].append(sum(fpr_list)/len(fpr_list))
    local_testing["f1"].append(sum(f1_list)/len(f1_list))
    local_testing["auc"].append(sum(auc_list)/len(auc_list))
    print("Test =>                 \tLoss: {:.10f} \tAcc: {:.10f} \tTPR:{:.10f} \tFPR:{:.10f} \tF1:{:.10f} \tAUC:{:.10f}".format(sum(loss_list)/len(loss_list), sum(acc_list)/len(acc_list), sum(tpr_list)/len(tpr_list), sum(fpr_list)/len(fpr_list), sum(f1_list)/len(f1_list), sum(auc_list)/len(auc_list)  ))

    print("-----------------------------------------------------------")
    print("-------------- FedServer: Federation process  -------------")
    print("-----------------------------------------------------------")
    w_net_glob = FedAvg(w_locals_client)
    net_glob.load_state_dict(w_net_glob)
elapsed = format_time(time.time()-t0)
print("Training and Test completed! total time cost: {:}".format(elapsed))

Train and Test Begin!
local_initialization!
deactivate ALA
Client8 Local Train => Local Epoch: 0 	Loss: 0.4946515006 	Acc: 0.7845784024 	TPR:0.9554121697 	FPR:0.7764274251 	F1:0.8689496695 	 AUC:0.5893596704	Train cost: 0:00:32
Client8 Local Train => Local Epoch: 1 	Loss: 0.4707483554 	Acc: 0.8028846154 	TPR:0.9532991431 	FPR:0.7089138277 	F1:0.8795974830 	 AUC:0.6221233686	Train cost: 0:00:32
Client8 Local Train => Local Epoch: 2 	Loss: 0.4134785727 	Acc: 0.8381102071 	TPR:0.9658915116 	FPR:0.5922620782 	F1:0.9000396764 	 AUC:0.6867597063	Train cost: 0:00:35
Client8 Local Train => Local Epoch: 3 	Loss: 0.3652202301 	Acc: 0.8578032544 	TPR:0.9597872379 	FPR:0.4843348959 	F1:0.9104992223 	 AUC:0.7377261710	Train cost: 0:00:35
Client8 Local Train => Local Epoch: 4 	Loss: 0.3129755075 	Acc: 0.8798076923 	TPR:0.9617996245 	FPR:0.3984608616 	F1:0.9233734644 	 AUC:0.7816127044	Train cost: 0:00:34
Client8 Test =>                 	Loss: 0.3871208485 	Acc: 0.8304166667 	TPR:0.8959629931 	FPR:0.

In [20]:
idx_collect = [i for i in range(num_users)]
print("============= Final Result =============")
for idx in idx_collect:
    loss_test_collect[idx] = []
    acc_test_collect[idx] = []
    TPR_test_collect[idx] = []
    FPR_test_collect[idx] = []
    f1_test_collect[idx] = []
    AUC_test_collect[idx] = []
    local = Client(device, idx, lr, local_epochs, batch_size, train_dataset, test_dataset, dict_user_train[idx], dict_user_test[idx])
    local.local_initialization(net = copy.deepcopy(net_glob).to(device))
    loss, acc, tpr, fpr, f1, auc = local.evaluate(net = copy.deepcopy(net_local[idx]).to(device), ell=0)
    loss_test_collect[idx].append(loss)
    acc_test_collect[idx].append(acc)
    TPR_test_collect[idx].append(tpr)
    FPR_test_collect[idx].append(fpr)
    f1_test_collect[idx].append(f1)
    AUC_test_collect[idx].append(auc)

local_initialization!
Client: 0 	Std: 0.026846555985805546 	ALA epochs: 1
Client 0: Local Initial ALA epochs: 1 Loss: 0.02836113898083567758
Client0 Test =>                 	Loss: 0.1971132714 	Acc: 0.9500000000 	TPR:0.9522314099 	FPR:0.0534716299 	F1:0.9497142165 	AUC:0.9493798900 	test cost: 0:00:05
local_initialization!
Client: 1 	Std: 0.09939059847416308 	ALA epochs: 6
Client 1: Local Initial ALA epochs: 6 Loss: 0.18544305231854177340
Client1 Test =>                 	Loss: 0.2159932375 	Acc: 0.9487500000 	TPR:0.9461497630 	FPR:0.0514951208 	F1:0.9448570736 	AUC:0.9473273211 	test cost: 0:00:05
local_initialization!
Client: 2 	Std: 0.03564530824012863 	ALA epochs: 1
Client 2: Local Initial ALA epochs: 1 Loss: 0.05153172613813005593
Client2 Test =>                 	Loss: 0.2127031225 	Acc: 0.9456250000 	TPR:0.9546251588 	FPR:0.0635646367 	F1:0.9442100739 	AUC:0.9455302610 	test cost: 0:00:05
local_initialization!
Client: 3 	Std: 0.056419750788617475 	ALA epochs: 1
Client 3: Local Ini

In [21]:
import xlwt
f = xlwt.Workbook('encoding = utf-8')
sheet1 = f.add_sheet('sheet1',cell_overwrite_ok=True)
for i in range(len(loss_collect)):
    sheet1.write(i+1,0,loss_collect[i]) #写入数据参数对应 行, 列, 值
for i in range(len(acc_collect)):
    sheet1.write(i+1,1,acc_collect[i])
for i in range(len(TPR_collect)):
    sheet1.write(i+1,2,TPR_collect[i])
for i in range(len(FPR_collect)):
    sheet1.write(i+1,3,FPR_collect[i])
for i in range(len(F1_collect)):
    sheet1.write(i+1,4,F1_collect[i])
for i in range(len(AUC_collect)):
    sheet1.write(i+1,5,AUC_collect[i])

f.save('result.xls')#保存.xls到当前工作目录

In [22]:
for i in range(num_users):
    f = xlwt.Workbook('encoding = utf-8')
    sheet1 = f.add_sheet('sheet1', cell_overwrite_ok=True)
    for j in range(len(loss_train_collect[i])):
        sheet1.write(j+1, 0, loss_train_collect[i][j])
    for j in range(len(acc_train_collect[i])):
        sheet1.write(j+1, 1, acc_train_collect[i][j])
    for j in range(len(TPR_train_collect[i])):
        sheet1.write(j+1, 2, TPR_train_collect[i][j])
    for j in range(len(FPR_train_collect[i])):
        sheet1.write(j+1, 3, FPR_train_collect[i][j])
    for j in range(len(f1_train_collect[i])):
        sheet1.write(j+1, 4, f1_train_collect[i][j])
    for j in range(len(AUC_train_collect[i])):
        sheet1.write(j+1, 5, AUC_train_collect[i][j])

    f.save('result_client{:}.xls'.format(i))

In [23]:
for i in range(num_users):
    f = xlwt.Workbook('encoding = utf-8')
    sheet1 = f.add_sheet('sheet1', cell_overwrite_ok=True)
    for j in range(len(loss_test_collect[i])):
        sheet1.write(j+1, 0, loss_test_collect[i][j])
    for j in range(len(acc_test_collect[i])):
        sheet1.write(j+1, 1, acc_test_collect[i][j])
    for j in range(len(TPR_test_collect[i])):
        sheet1.write(j+1, 2, TPR_test_collect[i][j])
    for j in range(len(FPR_test_collect[i])):
        sheet1.write(j+1, 3, FPR_test_collect[i][j])
    for j in range(len(f1_test_collect[i])):
        sheet1.write(j+1, 4, f1_test_collect[i][j])
    for j in range(len(AUC_test_collect[i])):
        sheet1.write(j+1, 5, AUC_test_collect[i][j])

    f.save('result_test_client{:}.xls'.format(i))

In [24]:
import xlwt
f = xlwt.Workbook('encoding = utf-8')
sheet1 = f.add_sheet('sheet1',cell_overwrite_ok=True)
for i in range(len(local_test["loss"])):
    sheet1.write(i+1,0,local_test["loss"][i])
for i in range(len(local_test["acc"])):
    sheet1.write(i+1,1,local_test["acc"][i])
for i in range(len(local_test["tpr"])):
    sheet1.write(i+1,2,local_test["tpr"][i])
for i in range(len(local_test["fpr"])):
    sheet1.write(i+1,3,local_test["fpr"][i])
for i in range(len(local_test["f1"])):
    sheet1.write(i+1,4,local_test["f1"][i])
for i in range(len(local_test["auc"])):
    sheet1.write(i+1,5,local_test["auc"][i])

f.save('Local_Test.xls')

In [25]:
import xlwt
f = xlwt.Workbook('encoding = utf-8')
sheet1 = f.add_sheet('sheet1',cell_overwrite_ok=True)
for i in range(len(local_testing["loss"])):
    sheet1.write(i+1,0,local_testing["loss"][i])
for i in range(len(local_testing["acc"])):
    sheet1.write(i+1,1,local_testing["acc"][i])
for i in range(len(local_testing["tpr"])):
    sheet1.write(i+1,2,local_testing["tpr"][i])
for i in range(len(local_testing["fpr"])):
    sheet1.write(i+1,3,local_testing["fpr"][i])
for i in range(len(local_testing["f1"])):
    sheet1.write(i+1,4,local_testing["f1"][i])
for i in range(len(local_testing["auc"])):
    sheet1.write(i+1,5,local_testing["auc"][i])

f.save('Local_Testing.xls')