In [None]:
import time
import torch
from torch import nn, optim
import sys
import random
import os
from collections import OrderedDict

# Sanity check: displays whether a CUDA device is visible to this kernel (the
# training `device` is later chosen from the same call).
torch.cuda.is_available()

In [None]:
# Attack families in the processed CICDDoS2019 data; keys of the per-type cursors below.
_ATTACK_TYPES = ('DNS', 'LDAP', 'MSSQL', 'NetBIOS', 'NTP', 'Portmap',
                 'SNMP', 'SSDP', 'Syn', 'TFTP', 'UDPLag', 'UDP')

# One shared shuffled row order (20000 indices) for every mixed-attack file;
# each attack type keeps its own cursor into it.
global_a_s = list(range(20000))
random.shuffle(global_a_s)
global_a_count = dict.fromkeys(_ATTACK_TYPES, 0)

# Random permutation of the 73500 benign training rows and the cursor into it.
global_n_s = torch.randperm(73500)
global_n_count = 0

# Per-type cursor into the "oe" attack files (read sequentially, not shuffled).
global_a_oe_count = dict.fromkeys(_ATTACK_TYPES, 0)

def sample_anomaly_from_mix(type_a,num_a):
    """Draw the next ``num_a`` rows of attack type ``type_a`` from the mixed-attack pool.

    Rows are taken without replacement: ``global_a_s`` is a shared shuffled index
    order and ``global_a_count[type_a]`` is this type's cursor into it, so
    successive calls for the same type never return the same row twice.

    Args:
        type_a: attack type name; selects ``attack_mix/<type_a>.pt``.
        num_a: number of rows to return.

    Returns:
        ``num_a`` rows of the loaded tensor.

    Raises:
        ValueError: if fewer than ``num_a`` shuffled indices remain for this type.
            Without this check the slice would silently return fewer rows than
            requested, while callers build labels assuming exactly ``num_a``.
    """
    global global_a_count
    start=global_a_count[type_a]
    end=start+num_a
    if num_a<0 or end>len(global_a_s):
        raise ValueError(
            f"cannot draw {num_a} '{type_a}' samples: only {len(global_a_s)-start} indices left")
    # NOTE: torch.load unpickles the file; only load trusted dataset files.
    o_data_x=torch.load("./../../dataset/processed_CICDDoS2019/result5/attack_mix/"+type_a+".pt")
    o_index=global_a_s[start:end]
    global_a_count[type_a]=end
    return o_data_x[o_index]

def sample_anomaly_from_oe(type_a,num_a):
    """Return the next ``num_a`` rows of attack type ``type_a`` from its "oe" file.

    Rows are consumed sequentially (no shuffling); ``global_a_oe_count[type_a]``
    is the cursor. ``load_client_data`` uses these rows as the anomaly part of a
    client's small clean dataset.

    Raises:
        AssertionError: if a single request asks for more than 30 rows.
        ValueError: if the file has fewer than ``num_a`` rows left after the cursor.
            Without this check the slice would silently come back short.
    """
    global global_a_oe_count
    assert num_a<=30
    # NOTE: torch.load unpickles the file; only load trusted dataset files.
    o_data_x=torch.load("./../../dataset/processed_CICDDoS2019/result5/oe/"+type_a+"_oe.pt")
    start=global_a_oe_count[type_a]
    t=o_data_x[start:start+num_a]
    if t.shape[0]!=num_a:
        raise ValueError(
            f"cannot draw {num_a} '{type_a}' oe samples: only {t.shape[0]} rows left")
    global_a_oe_count[type_a]=start+num_a
    return t

def load_normal(num_n):
    """Return the next ``num_n`` benign training rows, sampled without replacement.

    ``global_n_s`` is a fixed random permutation of the benign row indices and
    ``global_n_count`` is the cursor into it.

    Raises:
        ValueError: if fewer than ``num_n`` permuted indices remain. Without this
            check the slice would silently return fewer rows than requested.
    """
    # Original note: 56651+25707+13273 -- NOTE(review): presumably per-source
    # benign row counts; meaning unclear, confirm.
    global global_n_count
    if num_n<0 or global_n_count+num_n>len(global_n_s):
        raise ValueError(
            f"cannot draw {num_n} benign samples: only {len(global_n_s)-global_n_count} left")
    # NOTE: torch.load unpickles the file; only load trusted dataset files.
    n_data_x=torch.load("./../../dataset/processed_CICDDoS2019/result5/BENIGN/BENIGN_train.pt")
    start=global_n_count
    global_n_count+=num_n
    return n_data_x[global_n_s[start:global_n_count]]

In [None]:
def load_client_data(dirty_d,clean_d=(0,{})):
    """Build one client's data from its configuration.

    Args:
        dirty_d: ``(num_normal, {attack_type: count})`` for the possibly
            contaminated training set; attack rows come from the mixed pool.
        clean_d: ``(num_normal, {attack_type: count})`` for the small clean set;
            attack rows come from the "oe" files. Defaults to an empty clean set.

    Returns:
        ``[[dirty_x, dirty_y], [clean_anomaly_x, clean_normal_x], total_rows]``:
        - ``dirty_x``: contaminated inputs (normal rows first, then attack rows);
        - ``dirty_y``: their true labels (0 normal, 1 attack). This is auxiliary
          information only and must not be used for training;
        - ``clean_anomaly_x`` / ``clean_normal_x``: the anomaly / normal parts of
          the small clean set;
        - ``total_rows``: total number of rows returned.
    """
    dirty_normal_num=dirty_d[0]
    clean_normal_num=clean_d[0]
    normal_d=load_normal(dirty_normal_num+clean_normal_num)
    dirty_normal_x=normal_d[:dirty_normal_num]
    clean_normal_x=normal_d[dirty_normal_num:]

    # Gather all pieces first and concatenate once (avoids repeated re-allocation).
    dirty_parts=[dirty_normal_x]+[sample_anomaly_from_mix(k,n) for k,n in dirty_d[1].items()]
    dirty_data_x=torch.cat(dirty_parts,dim=0)
    # Labels are derived from the rows actually obtained, so x and y always align.
    num_dirty_normal=dirty_normal_x.shape[0]
    num_dirty_anomaly=dirty_data_x.shape[0]-num_dirty_normal
    dirty_data_y=torch.cat((torch.zeros(num_dirty_normal),torch.ones(num_dirty_anomaly)),dim=0).float()

    # Empty (0, n_features) seed keeps the feature width of the benign data even
    # when no clean anomalies are requested (width taken from the data, not hard-coded).
    clean_parts=[torch.zeros(0,normal_d.shape[1]).float()]
    clean_parts+=[sample_anomaly_from_oe(k,n) for k,n in clean_d[1].items()]
    tc=torch.cat(clean_parts,dim=0)

    total=dirty_data_x.shape[0]+clean_normal_x.shape[0]+tc.shape[0]
    return [[dirty_data_x,dirty_data_y],[tc,clean_normal_x],total]

In [None]:
class TabTransformNet(nn.Module):
    """Small MLP producing one learnable feature mask (as logits) for NeuTraL.

    Maps ``in_dim -> 48 -> 32 -> in_dim`` with ReLU activations. The output is
    returned as raw logits; NeuTraL applies the sigmoid.

    Args:
        in_dim: number of input features (66 for the processed CICDDoS2019 data,
            which is the default).
    """
    def __init__(self, in_dim=66):
        super(TabTransformNet, self).__init__()
        # Layer order is unchanged, so state_dict keys stay net.0 / net.2 / net.4
        # and saved checkpoints remain loadable.
        self.net=nn.Sequential(
            nn.Linear(in_dim, 48),
            nn.ReLU(),
            nn.Linear(48, 32),
            nn.ReLU(),
            nn.Linear(32, in_dim),
        )

    def forward(self, x):
        """Return mask logits with the same shape as ``x`` (``[..., in_dim]``)."""
        return self.net(x)

class NeuTraL(nn.Module):
    """NeuTraL anomaly-detection model: learnable feature masks plus a shared encoder.

    Args:
        num_trans: number of learnable transformations (one TabTransformNet each).
    """
    def __init__(self, num_trans):
        super(NeuTraL, self).__init__()
        # Shared encoder: 66 input features -> 32-dim embedding.
        self.enc=nn.Sequential(
            nn.Linear(66, 48),
            nn.ReLU(),
            nn.Linear(48, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
        )
        self.num_trans=num_trans
        self.trans = nn.ModuleList([TabTransformNet() for _ in range(self.num_trans)])

    def forward(self,x):
        """Encode ``x`` together with its ``num_trans`` masked views.

        Args:
            x: tensor of shape ``[batch_size, n_features]`` (66 features here).

        Returns:
            Tensor of shape ``[batch_size, num_trans + 1, z_dim]``. Index 0 along
            dim 1 is the embedding of the untransformed input.
        """
        # Each transformed view is the input scaled element-wise by a learned mask in (0, 1).
        views=[x]+[torch.sigmoid(t(x))*x for t in self.trans]
        x_cat=torch.stack(views,dim=1)  # [batch_size, num_trans+1, n_features]
        zs=self.enc(x_cat.reshape(-1,x.shape[-1]))
        return zs.reshape(x.shape[0],self.num_trans+1,zs.shape[-1])

In [None]:
def NeuTraL_loss(z,device,y,temp=0.1):
    """Per-sample, label-aware NeuTraL contrastive loss.

    For every transformed view k, ``p_k`` is the probability that view k is
    matched with the original view among all other views:

        p_k = exp(sim(z_k, z_0)/temp) / sum_{l != k} exp(sim(z_k, z_l)/temp)

    Samples with ``y == 0`` contribute ``-log p_k``. Samples with ``y == 1``
    (labelled anomalies) contribute ``-log(1 - p_k)``.

    Args:
        z: embeddings of shape ``[batch_size, num_trans + 1, z_dim]``. Index 0 on
            dim 1 is the untransformed input.
        device: device on which the loss is computed.
        y: labels of shape ``[batch_size]`` (0 = normal, 1 = anomaly).
        temp: similarity temperature.

    Returns:
        Tensor of shape ``[batch_size]``: the loss summed over the transformed views.
    """
    # Unit-length embeddings, so dot products are cosine similarities.
    z=nn.functional.normalize(z.to(device),p=2,dim=-1)
    z_ori=z[:,0]     # [n, z_dim]
    z_trans=z[:,1:]  # [n, k-1, z_dim]
    batch_size,num_views,_=z.shape  # num_views = num_trans + 1

    sim_matrix=torch.exp(torch.matmul(z,z.permute(0,2,1)/temp))  # [n, k, k]
    # Drop the self-similarities (diagonal). masked_select flattens, so reshape back.
    eye=torch.eye(num_views,device=sim_matrix.device).unsqueeze(0)  # [1, k, k]
    mask=(torch.ones_like(sim_matrix)-eye).bool()
    sim_matrix=sim_matrix.masked_select(mask).view(batch_size,num_views,-1)  # [n, k, k-1]
    trans_matrix=sim_matrix[:,1:].sum(-1)  # [n, k-1]: denominators of the transformed views

    pos_sim=torch.exp(torch.sum(z_trans*z_ori.unsqueeze(1),-1)/temp)  # [n, k-1]
    # pos_sim is one of the terms in trans_matrix, so p_k <= 1 mathematically. Clamp
    # so that rounding cannot push it above 1 and turn log(1 - p_k) into NaN.
    p_k=(pos_sim/(trans_matrix+1e-7)).clamp(max=1.0)

    y=y.to(p_k).view(-1,1)  # broadcast each sample's label over its views
    p_k_y=(1-y)*p_k+y*(1-p_k)
    return (-torch.log(p_k_y+1e-7)).sum(1)

def e_loss(x,net,device):
    """Anomaly score for each row of ``x`` under ``net``, computed without gradients.

    At evaluation time only the "normal" term of NeuTraL_loss is wanted, so the
    loss is evaluated with an all-zero label vector.

    Returns:
        Tensor of shape ``[x.shape[0]]``.
    """
    net=net.to(device)
    inputs=x.to(device)
    all_normal=torch.zeros(inputs.shape[0]).float().to(device)
    with torch.no_grad():
        scores=NeuTraL_loss(net(inputs),device,all_normal)
    return scores

In [None]:
def copy_model(m1, m2):
    """Overwrite every parameter/buffer of ``m2`` with a cloned copy from ``m1``.

    ``m1.state_dict()`` is built once rather than twice per key.
    """
    src=m1.state_dict()
    m2.load_state_dict(OrderedDict((k,v.clone()) for k,v in src.items()))

def FedAvg(model_list,data_num):
    """Weighted average of the state dicts of ``model_list``.

    Args:
        model_list: non-empty list of models with identical state-dict keys.
        data_num: one weight per model (e.g. sample counts). Weights are
            normalised by their sum.

    Returns:
        OrderedDict of averaged tensors on CPU, keyed like the first model.

    Raises:
        ValueError: if ``model_list`` is empty, the number of weights differs
            from the number of models, or the weights sum to 0.
    """
    if not model_list:
        raise ValueError("FedAvg needs at least one model")
    if len(data_num)!=len(model_list):
        raise ValueError(f"FedAvg got {len(model_list)} models but {len(data_num)} weights")
    total_num=sum(data_num)
    if total_num==0:
        raise ValueError("FedAvg weights sum to 0")
    # Build each state dict once instead of once per key.
    states=[m.state_dict() for m in model_list]
    d=OrderedDict()
    for k in states[0]:
        acc=states[0][k].cpu().clone()*(data_num[0]/total_num)
        for sd,n in zip(states[1:],data_num[1:]):
            acc+=sd[k].cpu()*(n/total_num)
        d[k]=acc
    return d

In [None]:
def calculate_gradient(m1,m2):
    """Return ``m2 - m1`` entry by entry as an OrderedDict (the update taking m1 to m2)."""
    before=m1.state_dict()
    after=m2.state_dict()
    return OrderedDict((name,after[name].clone()-before[name].clone()) for name in before)

def calculate_gradient_plus(m1,m2):
    """Return ``m1 + m2`` entry by entry as an OrderedDict keyed like ``m1``."""
    left,right=m1.state_dict(),m2.state_dict()
    total=OrderedDict()
    for name,tensor in left.items():
        total[name]=tensor.clone()+right[name].clone()
    return total

def model_cosine_similarity(a,b):
    """Cosine similarity between the flattened state dicts of two models.

    Both models are flattened in the key order of ``a``. The computation runs on
    the device of ``a``'s tensors (``b``'s are moved there if needed), so no
    module-level ``device`` variable is required.

    Returns:
        Python float in [-1, 1].
    """
    sd_a=a.state_dict()
    sd_b=b.state_dict()
    # One concatenation instead of growing a tensor key by key (which is quadratic).
    ta=torch.cat([v.reshape(-1) for v in sd_a.values()])
    tb=torch.cat([sd_b[k].reshape(-1) for k in sd_a]).to(ta.device)
    return torch.cosine_similarity(ta,tb,dim=0).item()

def model_distance_square(a,b):
    """Squared Euclidean distance between the flattened state dicts of ``a`` and ``b``.

    Entries are matched by ``a``'s keys. The difference is computed on the device
    of ``a``'s tensors, so no module-level ``device`` variable is required.

    Returns:
        Python float.
    """
    sd_a=a.state_dict()
    sd_b=b.state_dict()
    diff=torch.cat([(sd_a[k]-sd_b[k].to(sd_a[k].device)).reshape(-1) for k in sd_a])
    return (diff*diff).sum().item()

# ---------------------------------------------------
def calculate_model_length_square(m):
    """Squared L2 norm of all entries of ``m``'s state dict, flattened together.

    Runs on whatever device ``m``'s tensors live on (no module-level ``device`` needed).

    Returns:
        Python float.
    """
    flat=torch.cat([v.reshape(-1) for v in m.state_dict().values()])
    return (flat*flat).sum().item()

In [None]:
# Data configuration for 3 clusters of 7 clients each. In every cluster the first
# 6 entries are lower-tier clients and the last entry is the upper-tier client
# (called the middle-tier client in the training code).
# Each entry is:
#   ((normal count, {attack type: count}) for the contaminated dataset,
#    (normal count, {attack type: count}) for the small clean dataset)
# Available attack types:
# ['DNS', 'LDAP', 'MSSQL', 'NetBIOS', 'NTP', 'Portmap', 'SNMP', 'SSDP', 'Syn', 'TFTP', 'UDPLag', 'UDP']
# Number of normal rows replaced by attack rows in each contaminated client.
# NOTE(review): presumably must equal the sum of that client's attack counts
# (see the alternative settings below), so every dirty set stays at 3500 rows.
dirty_num=0

cluster1=[
    # clients 1-2: contaminated with the attack counts below
    ((3500-dirty_num,{'SNMP':0,'Portmap':0,'UDPLag': 0}),(0,{})),
    ((3500-dirty_num,{'SNMP':0,'Portmap':0,'UDPLag': 0}),(0,{})),
    # clients 3-6: purely normal
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    # upper-tier client.
    # NOTE(review): in cluster1 it is also contaminated, unlike cluster2/3 -- confirm intended.
    ((3500-dirty_num,{'SNMP':0,'Portmap':0,'UDPLag': 0}),
     # Alternative with 30 clean normal rows (FLTrust needs them when dirty_num != 0):
     # (30,{'NetBIOS': 10,'Syn':10,'LDAP': 10})
     (0,{'NetBIOS': 10,'Syn':10,'LDAP': 10})
    ),
]

cluster2=[
    # clients 1-2: contaminated with the attack counts below
    ((3500-dirty_num,{'SNMP':0,'Portmap':0,'UDPLag': 0}),(0,{})),
    ((3500-dirty_num,{'SNMP':0,'Portmap':0,'UDPLag': 0}),(0,{})),
    # clients 3-6: purely normal
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    # upper-tier client: normal dirty set plus 30 clean anomaly rows
    ((3500-0,{}),
     # Alternative with 30 clean normal rows (FLTrust needs them when dirty_num != 0):
     # (30,{'NetBIOS': 10,'Syn':10,'LDAP': 10})
     (0,{'NetBIOS': 10,'Syn':10,'LDAP': 10})
    ),
]

cluster3=[
    # clients 1-2: contaminated with the attack counts below
    ((3500-dirty_num,{'SNMP':0,'Portmap':0,'UDPLag': 0}),(0,{})),
    ((3500-dirty_num,{'SNMP':0,'Portmap':0,'UDPLag': 0}),(0,{})),
    # clients 3-6: purely normal
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    ((3500-0,{}),(0,{})),
    # upper-tier client: normal dirty set plus 30 clean anomaly rows
    ((3500-0,{}),
     # Alternative with 30 clean normal rows (FLTrust needs them when dirty_num != 0):
     # (30,{'NetBIOS': 10,'Syn':10,'LDAP': 10})
     (0,{'NetBIOS': 10,'Syn':10,'LDAP': 10})
    ),
]

# Alternative contamination settings for the dirty clients. The lines sum to
# 0, 70, 140, 210, 280 and 350 attack rows; presumably used with dirty_num set to that sum.
# 'SNMP':0,'Portmap':0,'UDPLag': 0
# 'SNMP':24,'Portmap':23,'UDPLag': 23
# 'SNMP':47,'Portmap':47,'UDPLag': 46
# 'SNMP':70,'Portmap':70,'UDPLag': 70
# 'SNMP':94,'Portmap':93,'UDPLag': 93
# 'SNMP':117,'Portmap':117,'UDPLag': 116

# Display the number of clients per cluster.
len(cluster1),len(cluster2),len(cluster3)

In [None]:
# Materialise every client's data in a fixed order (cluster 1, 2, 3; clients in
# list order). The order matters because the samplers advance shared cursors.
client_data=[
    [load_client_data(dirty_cfg,clean_cfg) for dirty_cfg,clean_cfg in cluster_cfg]
    for cluster_cfg in (cluster1,cluster2,cluster3)
]

# Per client: (dirty x, dirty y, clean anomalies, clean normals) shapes.
for cluster_clients in client_data:
    print([(d[0][0].shape,d[0][1].shape,d[1][0].shape,d[1][1].shape) for d in cluster_clients])

In [None]:
# Number of learnable transformations per NeuTraL model.
t_num=4

# One local model per client, one aggregated model per cluster, one global model.
# Construction order (cluster 1..3 locals, cluster models, global) fixes RNG use.
local_model_list=[[NeuTraL(t_num) for _ in cluster_cfg] for cluster_cfg in (cluster1,cluster2,cluster3)]
cluster_global_model_list=[NeuTraL(t_num) for _ in client_data]
global_model=NeuTraL(t_num)

# Optimisation hyper-parameters.
lr=0.01
weight_decay=0
num_epochs=800
batch_size=1024

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def train_single(model,X,y, lr,weight_decay, device):
    """Take a single Adam step on the mean NeuTraL_loss of the batch ``(X, y)``.

    A fresh Adam optimiser is created on every call, so no moment estimates
    carry over between calls.
    """
    model=model.to(device)
    inputs,labels=X.to(device),y.to(device)
    opt=torch.optim.Adam(model.parameters(),lr=lr,weight_decay=weight_decay)
    loss=NeuTraL_loss(model(inputs),device,labels).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

def train_cluster(local_model_list,cluster_global_model,
                  cluster_data, lr,weight_decay, device,batch_size):
    """Give every client of one cluster one local training step (one epoch).

    Each client starts from ``cluster_global_model``.
    - Ordinary clients (all but the last) train on a random batch of their
      contaminated data, with every row labelled normal (0).
    - The last (middle-tier) client fills the same batch size with contaminated
      rows labelled 0 plus all of its clean anomaly rows labelled 1.

    ``local_model_list`` entries are moved to ``device`` and updated in place.
    """
    cluster_global_model=cluster_global_model.to(device)
    for idx in range(len(local_model_list)-1):
        local_model_list[idx]=local_model_list[idx].to(device)
        copy_model(cluster_global_model, local_model_list[idx])

        rows=random.sample(range(cluster_data[idx][0][1].shape[0]), k=batch_size)
        batch_x=cluster_data[idx][0][0][rows]
        batch_y=torch.zeros(batch_size).float()
        train_single(local_model_list[idx],batch_x,batch_y,lr,weight_decay, device)

    # Middle-tier client: contaminated rows (label 0) + every clean anomaly row (label 1).
    local_model_list[-1]=local_model_list[-1].to(device)
    copy_model(cluster_global_model, local_model_list[-1])
    clean_anomalies=cluster_data[-1][1][0]
    n_clean=clean_anomalies.shape[0]
    n_dirty=batch_size-n_clean
    rows=random.sample(range(cluster_data[-1][0][1].shape[0]), k=n_dirty)
    batch_x=torch.cat((cluster_data[-1][0][0][rows],clean_anomalies),dim=0)
    batch_y=torch.cat((torch.zeros(n_dirty).float(),torch.ones(n_clean)),dim=0).float()
    train_single(local_model_list[-1],batch_x,batch_y,lr,weight_decay, device)

In [None]:
def train_hierarchical(local_model_list,cluster_global_model_list,global_model,
                       client_data,lr,weight_decay, device,batch_size,
                       num_epochs,agg_f,print_e=5):
    """Hierarchical federated training: clients -> cluster model -> global model.

    Each epoch:
      1. the global model is copied into every cluster model, and each cluster's
         clients take one local step from it (``train_cluster``);
      2. each cluster aggregates its clients with
         ``agg_f(local_models, global_model, epoch, cluster_id)``, which must
         return a state dict;
      3. the cluster models are averaged with equal weights (FedAvg) into
         ``global_model``.

    ``global_model`` and the list arguments are updated in place. Progress,
    including elapsed wall-clock time, is printed every ``print_e`` epochs.
    """
    print("training on", device)
    start=time.time()
    num_cluster=len(client_data)

    global_model=global_model.to(device)
    for j in range(num_cluster):
        cluster_global_model_list[j]=cluster_global_model_list[j].to(device)

    for i in range(num_epochs):
        for j in range(num_cluster):
            # Start this cluster from the current global model, then train its clients.
            copy_model(global_model, cluster_global_model_list[j])
            train_cluster(local_model_list[j],cluster_global_model_list[j],
                          client_data[j],lr,weight_decay, device,batch_size)
        # All clusters are trained before any aggregation, so agg_f may inspect
        # other clusters' freshly trained clients (c_s_based_type3 does).
        for j in range(num_cluster):
            c_g_d=agg_f(local_model_list[j],global_model,i,j)
            cluster_global_model_list[j].load_state_dict(c_g_d)

        # Inter-cluster aggregation: plain FedAvg with equal cluster weights.
        global_d=FedAvg(cluster_global_model_list,[1]*num_cluster)
        global_model.load_state_dict(global_d)

        if (i+1)%print_e==0:
            print("epoch",i+1,"elapsed %.1fs"%(time.time()-start))

In [None]:
def Trimmed_Mean_3(model_list,global_model,epoch,cluster_id,c=3):
    """Coordinate-wise trimmed mean of the client models.

    For every state-dict entry, the clients' values are sorted per coordinate,
    the ``c`` smallest and ``c`` largest are dropped, and the rest are averaged.
    ``global_model``, ``epoch`` and ``cluster_id`` are unused; they are part of
    the aggregation-function interface.

    Args:
        c: number of values trimmed from each end (default 3).

    Returns:
        OrderedDict of float32 CPU tensors with the original shapes.

    Raises:
        ValueError: if there are not more than ``2*c`` models. Previously this
            case either divided by zero or silently produced an all-zero model.
    """
    num=len(model_list)
    if num<=2*c:
        raise ValueError(f"trimmed mean with c={c} needs more than {2*c} models, got {num}")
    states=[m.state_dict() for m in model_list]
    result_d=OrderedDict()
    for name,ref in states[0].items():
        # [num, numel]: one flattened row per client; vectorised instead of a
        # per-coordinate Python loop.
        stacked=torch.stack([sd[name].reshape(-1).float().cpu() for sd in states])
        kept=torch.sort(stacked,dim=0).values[c:num-c]
        result_d[name]=kept.mean(dim=0).view(ref.shape)
    return result_d

def Krum_3(model_list,global_model,epoch,cluster_id):
    """Krum aggregation assuming c = 3 Byzantine clients.

    Each model is scored by the sum of squared distances to its ``num - c - 2``
    nearest other models. The lowest-scoring model's index is printed and a
    cloned copy of its state dict is returned. ``global_model``, ``epoch`` and
    ``cluster_id`` are unused.
    """
    c=3
    num=len(model_list)
    # Symmetric matrix of squared distances; only the upper triangle is computed.
    dist=[[0]*num for _ in range(num)]
    for i in range(num):
        for j in range(i+1,num):
            dist[i][j]=dist[j][i]=model_distance_square(model_list[i],model_list[j])

    best_idx,best_score=-1,float("inf")
    for idx,row in enumerate(dist):
        # The smallest entry of a sorted row is the zero self-distance, so skip index 0.
        score=sum(sorted(row)[1:num-c-1])
        if score<best_score:
            best_idx,best_score=idx,score

    print(best_idx,end='\t\t')
    chosen=model_list[best_idx].state_dict()
    return OrderedDict((k,v.clone()) for k,v in chosen.items())

def FLTrust(model_list,global_model,epoch,cluster_id):
    """FLTrust aggregation for one cluster.

    A trusted reference update is obtained by taking one training step from
    ``global_model`` on normal data held by the cluster's middle-tier client
    (the last client). The reference itself is not aggregated. Each client
    update (model - global_model) is weighted by its ReLU-clipped cosine
    similarity to the reference, and rescaled to the reference's norm. The
    weighted mean update is then added back to ``global_model``.

    Reads the module-level ``client_data``, ``dirty_num``, ``t_num``, ``lr``,
    ``weight_decay`` and ``device``.

    Returns:
        OrderedDict state dict of the aggregated cluster model. If no client has
        a positive similarity to the reference, the global model's state is
        returned unchanged (as in the FLTrust paper).

    Raises:
        ValueError: if there are no trusted normal rows for the reference step.
    """
    global client_data
    global dirty_num
    t_m=NeuTraL(t_num).to(device)
    copy_model(global_model, t_m)
    if dirty_num==0:
        # Only 73500 benign training rows exist. With dirty_num == 0 they are all
        # consumed by the contaminated datasets, so the last 30 rows of the
        # middle-tier client's contaminated set (all normal under that config)
        # stand in for the clean normal set, which only FLTrust uses.
        X=client_data[cluster_id][-1][0][0][-30:]
        print('dirty_num=0!!!!!!!!!!!!!!!',X.shape)
    else:
        X=client_data[cluster_id][-1][1][1]
    if X.shape[0]==0:
        # An empty batch makes the mean loss NaN and would poison every weight below.
        raise ValueError("FLTrust: no clean normal rows for the reference update; "
                         "give the middle-tier client a non-empty clean normal set")
    y=torch.zeros(X.shape[0]).float()
    train_single(t_m,X,y,lr,weight_decay, device)
    t_m.load_state_dict(calculate_gradient(global_model,t_m))  # t_m now holds the reference update

    g_list=[NeuTraL(t_num).to(device) for _ in model_list]
    for g,m in zip(g_list,model_list):
        g.load_state_dict(calculate_gradient(global_model,m))  # client update
    cs_l=[model_cosine_similarity(g,t_m) for g in g_list]
    TS=[max(0,s) for s in cs_l]  # trust scores: ReLU of the cosine similarity
    sum_TS=sum(TS)
    g_length=[calculate_model_length_square(g)**0.5 for g in g_list]
    g0_l=calculate_model_length_square(t_m)**0.5
    print(cs_l)

    global_sd=global_model.state_dict()
    if sum_TS==0:
        # No client is trusted: keep the global model (avoids division by zero).
        return OrderedDict((k,v.clone()) for k,v in global_sd.items())

    # Per-client coefficient: trust score * (reference norm / client norm),
    # normalised by the total trust. A zero-length update contributes nothing.
    coef=[0.0 if length==0 else g0_l*ts/length/sum_TS for ts,length in zip(TS,g_length)]
    g_sds=[g.state_dict() for g in g_list]
    d=OrderedDict()
    for k,base in global_sd.items():
        update=sum(w*sd[k] for w,sd in zip(coef,g_sds))
        # The weighted sum is an update; add the global model to turn it into a model.
        d[k]=update+base.clone()
    return d

In [None]:
def AVG(model_list,global_model,epoch,cluster_id):
    """Plain equal-weight FedAvg with no defence; the extra arguments are unused."""
    equal_weights=[1 for _ in model_list]
    return FedAvg(model_list,data_num=equal_weights)

# Shared state for c_s_based_type3:
#   entries 0..2: [clean-reference cosine, polluted-reference cosine] for each
#     cluster's middle-tier client;
#   entry -2: current similarity threshold;
#   entry -1: last epoch in which the entries were computed (-1 = never).
# The initial 2 / -2 values lie outside the cosine range [-1, 1]; presumably sentinels.
global_c_s_info=[[2,-2],[2,-2],[2,-2],0,-1]

In [None]:
def c_s_based_type3(model_list,global_model,epoch,cluster_id):
    """Cosine-similarity filter aggregation ("type 3") for one cluster.

    Once per epoch (on the first call that sees a new ``epoch``), for every
    cluster the middle-tier client (last in its list) builds two reference
    updates from the global model on its own data:
    - a "clean" one, in which its clean anomaly rows are labelled as anomalies;
    - a "polluted" one, in which every row is labelled normal.
    The cosine similarity of each reference to that client's actual update is
    stored in ``global_c_s_info``. The threshold is mean + 2 * std of the
    polluted similarities across clusters.

    Then, for the cluster being aggregated (``model_list``, whose last entry is
    its middle-tier client), every other client whose update has a cosine
    similarity to the middle-tier update above the threshold is kept. The
    middle-tier client is always kept. The kept models are averaged with equal
    weights.

    Reads/writes module-level ``client_data``, ``global_c_s_info`` and
    ``local_model_list``; also reads ``t_num``, ``lr``, ``weight_decay`` and
    ``device``. ``cluster_id`` is unused.
    NOTE(review): ``global_c_s_info`` has exactly 3 per-cluster slots, so this
    assumes at most 3 clusters.
    """
    global client_data
    global global_c_s_info
    global local_model_list
    
    def get_th_type3(n,a):
        """Threshold = mean(a) + 2 * std(a); the first argument ``n`` is ignored."""
        u,v=torch.std_mean(torch.tensor(a))  # u = std, v = mean
        return float(v)+2*float(u)

    num_cluster=len(client_data) # number of clusters (one middle-tier client each)

    if global_c_s_info[-1]<epoch:
        # Simulate the middle-tier clients of all clusters exchanging cosine-similarity information.
        for i in range(num_cluster):
            t_base=NeuTraL(t_num).to(device)
            copy_model(local_model_list[i][-1], t_base)
            d=calculate_gradient(global_model,t_base)   # this middle-tier client's update w.r.t. the global model
            t_base.load_state_dict(d)

            # Clean reference update: 980 contaminated rows labelled 0 plus 20
            # clean anomaly rows labelled 1 (both sampled with replacement).
            t_m=NeuTraL(t_num).to(device)
            copy_model(global_model, t_m)
            k1=1000-20
            k2=20
            s_n=random.choices(range(client_data[i][-1][0][0].shape[0]), k=k1)
            s_a=random.choices(range(client_data[i][-1][1][0].shape[0]), k=k2)
            X=torch.cat((client_data[i][-1][0][0][s_n],client_data[i][-1][1][0][s_a]),dim=0)
            y1=torch.cat((torch.zeros(k1).float(),torch.ones(k2).float()),dim=0)
            train_single(t_m,X,y1,lr,weight_decay, device)
            d=calculate_gradient(global_model,t_m)
            t_m.load_state_dict(d)
            global_c_s_info[i][0]=model_cosine_similarity(t_base,t_m)
            
            # Polluted reference update: the same rows, all labelled normal.
            y2=torch.zeros(k1+k2).float()
            copy_model(global_model, t_m)
            train_single(t_m,X,y2,lr,weight_decay, device)
            d=calculate_gradient(global_model,t_m)
            t_m.load_state_dict(d)
            global_c_s_info[i][1]=model_cosine_similarity(t_base,t_m)

        # Mark the cached similarities as computed for the current epoch.
        global_c_s_info[-1]=epoch
        global_c_s_info[-2]=get_th_type3([global_c_s_info[i][0] for i in range(num_cluster)],  # clean similarities of all clusters (ignored by get_th_type3)
                                         [global_c_s_info[i][1] for i in range(num_cluster)])
        print(global_c_s_info)

    # Filter this cluster's ordinary clients against its middle-tier client's update.
    agg_list=[]
    t_base=NeuTraL(t_num).to(device)
    copy_model(model_list[-1], t_base)
    d=calculate_gradient(global_model,t_base)
    t_base.load_state_dict(d)
    c_s_th=global_c_s_info[-2]
    for i in range(len(model_list)-1):
        t_m=NeuTraL(t_num).to(device)
        copy_model(model_list[i], t_m)
        d=calculate_gradient(global_model,t_m)
        t_m.load_state_dict(d)
        if model_cosine_similarity(t_base,t_m)>c_s_th:
            agg_list.append(i)
    agg_list.append(-1)  # the middle-tier client is always aggregated
    print(agg_list,c_s_th)
    return FedAvg([model_list[i] for i in agg_list],data_num=[1]*len(agg_list))

In [None]:
def RFA(model_list,global_model,epoch,cluster_id,maxiter=4,eps=1e-5,ftol=1e-6):
    """Robust Federated Aggregation: approximate geometric median of the client models.

    Uses the smoothed Weiszfeld iteration. Starting from the equal-weight average,
    each step re-averages the models with weights
    ``alpha_i / max(eps, ||median - model_i||)``. It stops after ``maxiter``
    steps, or earlier when the objective changes by less than ``ftol`` (relative).
    ``global_model``, ``epoch`` and ``cluster_id`` are unused.

    Args:
        maxiter: maximum number of Weiszfeld steps.
        eps: lower bound on a distance, to avoid dividing by zero.
        ftol: relative objective tolerance for early stopping.
        The defaults reproduce the original hard-coded values.

    Returns:
        OrderedDict state dict of the median model.
    """
    def geometric_median_objective(median, points, alphas):
        """Weighted sum of Euclidean distances from ``median`` to each model in ``points``."""
        return sum(alpha*(model_distance_square(median,p)**0.5) for alpha,p in zip(alphas,points))

    num_m=len(model_list)
    alphas=[1/num_m for _ in range(num_m)]  # all clients hold the same number of samples

    # Start from the plain FedAvg average.
    median_d=FedAvg(model_list,data_num=alphas)
    median=NeuTraL(t_num).to(device)
    median.load_state_dict(median_d)

    obj_val=geometric_median_objective(median,model_list,alphas)
    for _ in range(maxiter):
        prev_obj_val=obj_val
        weights=[alpha/max(eps,model_distance_square(median,m)**0.5)
                 for alpha,m in zip(alphas,model_list)]
        median.load_state_dict(FedAvg(model_list,data_num=weights))
        obj_val=geometric_median_objective(median,model_list,alphas)
        if abs(prev_obj_val-obj_val)<ftol*obj_val:
            break

    return OrderedDict((k,v.clone()) for k,v in median.state_dict().items())

In [None]:
# Aggregation rules to compare. Each rule is trained from fresh models on the same client data.
agg_dic={'Krum_3':Krum_3, 'Trimmed_Mean_3':Trimmed_Mean_3, 'AVG':AVG,
         'FLTrust':FLTrust,'RFA':RFA,'c_s_based_type3':c_s_based_type3}

for k in agg_dic.keys():

    if 'c_s_based' in k:
        # Reset the shared similarity cache so state from a previous run cannot leak in.
        print(k,'change global_c_s_info')
        global_c_s_info=[[2,-2],[2,-2],[2,-2],0,-1]

    # Fresh models for every rule. These must stay module-level names, because
    # c_s_based_type3 reads the global `local_model_list`.
    local_model_list=[[NeuTraL(t_num) for i in range(len(cluster1))],
                      [NeuTraL(t_num) for i in range(len(cluster2))],
                      [NeuTraL(t_num) for i in range(len(cluster3))],
                     ]
    cluster_global_model_list=[NeuTraL(t_num) for i in range(len(client_data))]
    global_model=NeuTraL(t_num)

    print(k,'------------------------------------------------------------------------')

    train_hierarchical(local_model_list,cluster_global_model_list,global_model,
                       client_data,lr,weight_decay, device,batch_size,num_epochs,
                       agg_f=agg_dic[k],
                       print_e=50)

    e_name=k+'-D'+str(dirty_num)
    result_dir=os.path.join('.','try1_result',e_name)
    # exist_ok=True: re-running the notebook must not crash on an existing directory.
    os.makedirs(result_dir,exist_ok=True)
    save_path=os.path.join(result_dir,"model.pt")
    torch.save(global_model.state_dict(), save_path)