# The notebook contains
### Code for _Bulyan_ aggregation algorithm
### Evaluation of all the attacks (Fang, LIE, and our AGR-agnstic) on Multi-krum, except our AGR-tailored attack on Bulyan

In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [2]:
from __future__ import print_function
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math, json
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

sys.path.insert(0,'./../utils/')
from logger import *
from eval import *
from misc import *

from femnist_normal_train import *
from femnist_util import *
from adam import Adam
from sgd import SGD
import torchvision.transforms as transforms
import torchvision.datasets as datasets

## Get the FEMNIST dataset; we use [LEAF framework](https://leaf.cmu.edu/)

In [3]:
user_tr_data = []
user_tr_labels = []

for i in range(34):
    f = './leaf/data/femnist/data/train/all_data_%d_niid_05_keep_0_train_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_tr_data.append(obj['user_data'][user]['x'])
        user_tr_labels.append(obj['user_data'][user]['y'])

user_te_data = []
user_te_labels = []

for i in range(34):
    f = './leaf/data/femnist/data/test/all_data_%d_niid_05_keep_0_test_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_te_data.append(obj['user_data'][user]['x'])
        user_te_labels.append(obj['user_data'][user]['y'])

In [4]:
user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(len(user_tr_data)):
    
    user_tr_data_tensor=torch.from_numpy(np.array(user_tr_data[i])).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(np.array(user_tr_labels[i])).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    
    print('user %d tr len %d'%(i,len(user_tr_data_tensor)))

user 0 tr len 159
user 1 tr len 158
user 2 tr len 162
user 3 tr len 161
user 4 tr len 109
user 5 tr len 152
user 6 tr len 158
user 7 tr len 153
user 8 tr len 162
user 9 tr len 135
user 10 tr len 156
user 11 tr len 151
user 12 tr len 83
user 13 tr len 162
user 14 tr len 153
user 15 tr len 55
user 16 tr len 163
user 17 tr len 130
user 18 tr len 162
user 19 tr len 156
user 20 tr len 148
user 21 tr len 161
user 22 tr len 146
user 23 tr len 121
user 24 tr len 153
user 25 tr len 133
user 26 tr len 33
user 27 tr len 135
user 28 tr len 159
user 29 tr len 134
user 30 tr len 162
user 31 tr len 157
user 32 tr len 25
user 33 tr len 152
user 34 tr len 128
user 35 tr len 161
user 36 tr len 165
user 37 tr len 163
user 38 tr len 7
user 39 tr len 133
user 40 tr len 156
user 41 tr len 156
user 42 tr len 155
user 43 tr len 156
user 44 tr len 22
user 45 tr len 158
user 46 tr len 163
user 47 tr len 162
user 48 tr len 155
user 49 tr len 135
user 50 tr len 160
user 51 tr len 153
user 52 tr len 156
user 53 tr

In [5]:
te_data = np.concatenate(user_te_data, 0)
te_labels = np.concatenate(user_te_labels)
te_len = len(te_labels)

te_data_tensor = torch.from_numpy(te_data[:(te_len//2)]).type(torch.FloatTensor)
te_label_tensor = torch.from_numpy(te_labels[:(te_len//2)]).type(torch.LongTensor)

val_data_tensor = torch.from_numpy(te_data[(te_len//2):]).type(torch.FloatTensor)
val_label_tensor = torch.from_numpy(te_labels[(te_len//2):]).type(torch.LongTensor)

## Code for Bulyan aggregation algorithm

In [6]:
def bulyan(all_updates, n_attackers):
    nusers = all_updates.shape[0]
    bulyan_cluster = []
    candidate_indices = []
    remaining_updates = all_updates
    all_indices = np.arange(len(all_updates))

    while len(bulyan_cluster) < (nusers - 2 * n_attackers):
        distances = []
        for update in remaining_updates:
            distance = torch.norm((remaining_updates - update), dim=1) ** 2
            distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)

        distances = torch.sort(distances, dim=1)[0]

        scores = torch.sum(distances[:, :len(remaining_updates) - 2 - n_attackers], dim=1)
        indices = torch.argsort(scores)[:len(remaining_updates) - 2 - n_attackers]

        candidate_indices.append(all_indices[indices[0].cpu().numpy()])
        all_indices = np.delete(all_indices, indices[0].cpu().numpy())
        bulyan_cluster = remaining_updates[indices[0]][None, :] if not len(bulyan_cluster) else torch.cat((bulyan_cluster, remaining_updates[indices[0]][None, :]), 0)
        remaining_updates = torch.cat((remaining_updates[:indices[0]], remaining_updates[indices[0] + 1:]), 0)

    # print('dim of bulyan cluster ', bulyan_cluster.shape)

    n, d = bulyan_cluster.shape
    param_med = torch.median(bulyan_cluster, dim=0)[0]
    sort_idx = torch.argsort(torch.abs(bulyan_cluster - param_med), dim=0)
    sorted_params = bulyan_cluster[sort_idx, torch.arange(d)[None, :]]

    return torch.mean(sorted_params[:n - 2 * n_attackers], dim=0), np.array(candidate_indices)

## Code for Fang attack on Bulyan
### Fang attacks on MKrum and Bulyan are the same

In [8]:
def compute_lambda(all_updates, model_re, n_attackers):

    distances = []
    n_benign, d = all_updates.shape
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1)
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)

    distances[distances == 0] = 10000
    distances = torch.sort(distances, dim=1)[0]
    scores = torch.sum(distances[:, :n_benign - 2 - n_attackers], dim=1)
    min_score = torch.min(scores)
    term_1 = min_score / ((n_benign - n_attackers - 1) * torch.sqrt(torch.Tensor([d]))[0])
    max_wre_dist = torch.max(torch.norm((all_updates - model_re), dim=1)) / (torch.sqrt(torch.Tensor([d]))[0])

    return (term_1 + max_wre_dist)


def get_malicious_updates_fang(all_updates, model_re, deviation, n_attackers):

    lamda = compute_lambda(all_updates, model_re, n_attackers)

    threshold = 1e-5
    mal_update = []

    while lamda > threshold:
        mal_update = (-lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        # print(mal_updates.shape, n_attackers)
        agg_grads, krum_candidate = multi_krum(mal_updates, n_attackers, multi_k=False)
        if krum_candidate < n_attackers:
            # print('successful lamda is ', lamda)
            return mal_update
        else:
            mal_update = []

        lamda *= 0.5

    if not len(mal_update):
        mal_update = (model_re - lamda * deviation)
        
    return mal_update

In [15]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.float if torch.cuda.is_available() else torch.float#torch.cuda.is_available()
batch_size = 100
schedule = [2000]

aggregation = 'bulyan'
at_type = 'fang'
chkpt = './' + aggregation
epoch_num = 0

at_fractions = [20]

for at_fraction in at_fractions:

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    print('==> Initializing global model')
    epoch_num = 0
    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = max(2, np.sum(round_users < (34*at_fraction)))
        if n_attacker > 14:
            n_attacker = 14

        # print('n_attackers is ', n_attackers)
        attacker_count = 0
        for i in round_users:
            if i < (34*at_fraction) and attacker_count < n_attacker:
                attacker_count += 1
                continue

            inputs = user_tr_data_tensors[i]
            print( inputs)
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        malicious_grads = user_grads

        if n_attacker > 0:
            if at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                mal_update = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            
            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)
            
        if epoch_num == 0: print('malicious grads shape ', malicious_grads.shape)

        if aggregation == 'mean':
            agg_grads=torch.mean(malicious_grads,dim=0)
            
        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k, verbose=True)
        
        elif aggregation == 'bulyan':
            agg_grads, krum_candidate = bulyan(malicious_grads, n_attacker)
            
        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 10 == 0:
            print('%s: at %s at_frac %.1f n_at %d n_mal_sel %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, at_fraction, n_attacker, np.sum(krum_candidate < n_attacker), epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

==> Initializing global model


IndexError: list index out of range

## Our first AGR-agnostic attack - Min-Max

In [18]:
def our_attack_dist(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    max_distance = torch.max(distances)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        max_d = torch.max(distance)
        
        if max_d <= max_distance:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

In [19]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [2000]

aggregation = 'bulyan'
at_type = 'min-max'
chkpt = './' + aggregation
epoch_num = 0

at_fractions = [20]

for at_fraction in at_fractions:

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    print('==> Initializing global model')
    epoch_num = 0
    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = max(2, np.sum(round_users < (34*at_fraction)))
        if n_attacker > 14: n_attacker = 14

        attacker_count = 0
        for i in round_users:
            if i < (34*at_fraction) and attacker_count < n_attacker:
                attacker_count += 1
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        malicious_grads = user_grads

        if n_attacker > 0:
            if at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                mal_update = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type='sign')
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type='sign')
                
            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)
            
        if epoch_num == 0: print('malicious grads shape ', malicious_grads.shape)

        if aggregation == 'mean':
            agg_grads=torch.mean(malicious_grads,dim=0)
            
        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k, verbose=True)
        
        elif aggregation == 'bulyan':
            agg_grads, krum_candidate = bulyan(malicious_grads, n_attacker)
            
        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 10 == 0:
            print('%s: at %s at_frac %.1f n_at %d n_mal_sel %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, at_fraction, n_attacker, np.sum(krum_candidate < n_attacker), epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

==> Initializing global model
malicious grads shape  torch.Size([60, 848382])
bulyan: at min-max at_frac 20.0 n_at 11 n_mal_sel 4 e 0 fed_model val loss 3.9545 val acc 3.9642 best val_acc 3.964168 te_acc 4.437809
bulyan: at min-max at_frac 20.0 n_at 8 n_mal_sel 3 e 10 fed_model val loss 3.8240 val acc 8.8962 best val_acc 10.502471 te_acc 11.395696
bulyan: at min-max at_frac 20.0 n_at 14 n_mal_sel 5 e 20 fed_model val loss 3.8531 val acc 13.1770 best val_acc 17.645696 te_acc 19.666392
bulyan: at min-max at_frac 20.0 n_at 10 n_mal_sel 4 e 30 fed_model val loss 3.8279 val acc 26.9924 best val_acc 26.992381 te_acc 30.246087
bulyan: at min-max at_frac 20.0 n_at 11 n_mal_sel 3 e 40 fed_model val loss 3.6456 val acc 29.0157 best val_acc 30.032434 te_acc 33.991454
bulyan: at min-max at_frac 20.0 n_at 12 n_mal_sel 3 e 50 fed_model val loss 3.2941 val acc 34.1227 best val_acc 34.122735 te_acc 38.279963
bulyan: at min-max at_frac 20.0 n_at 11 n_mal_sel 3 e 60 fed_model val loss 3.1105 val acc 31.

bulyan: at min-max at_frac 20.0 n_at 13 n_mal_sel 5 e 590 fed_model val loss 1.9437 val acc 55.0376 best val_acc 57.745572 te_acc 60.553954
bulyan: at min-max at_frac 20.0 n_at 12 n_mal_sel 5 e 600 fed_model val loss 1.9330 val acc 56.6928 best val_acc 57.745572 te_acc 60.553954
bulyan: at min-max at_frac 20.0 n_at 14 n_mal_sel 4 e 610 fed_model val loss 1.9031 val acc 56.3916 best val_acc 57.745572 te_acc 60.553954
bulyan: at min-max at_frac 20.0 n_at 12 n_mal_sel 4 e 620 fed_model val loss 1.9769 val acc 55.6219 best val_acc 57.745572 te_acc 60.553954
bulyan: at min-max at_frac 20.0 n_at 13 n_mal_sel 5 e 630 fed_model val loss 2.1460 val acc 55.3388 best val_acc 57.745572 te_acc 60.553954
bulyan: at min-max at_frac 20.0 n_at 14 n_mal_sel 5 e 640 fed_model val loss 2.0416 val acc 56.7674 best val_acc 57.745572 te_acc 60.553954
bulyan: at min-max at_frac 20.0 n_at 14 n_mal_sel 4 e 650 fed_model val loss 1.9592 val acc 56.4405 best val_acc 58.157434 te_acc 61.140857
bulyan: at min-max a

bulyan: at min-max at_frac 20.0 n_at 14 n_mal_sel 5 e 1180 fed_model val loss 1.6733 val acc 61.9491 best val_acc 62.844934 te_acc 65.272343
bulyan: at min-max at_frac 20.0 n_at 13 n_mal_sel 4 e 1190 fed_model val loss 1.6888 val acc 62.0470 best val_acc 62.844934 te_acc 65.272343
bulyan: at min-max at_frac 20.0 n_at 14 n_mal_sel 4 e 1200 fed_model val loss 1.7707 val acc 59.5629 best val_acc 62.844934 te_acc 65.272343
bulyan: at min-max at_frac 20.0 n_at 9 n_mal_sel 3 e 1210 fed_model val loss 1.6172 val acc 62.2941 best val_acc 62.844934 te_acc 65.272343
bulyan: at min-max at_frac 20.0 n_at 14 n_mal_sel 4 e 1220 fed_model val loss 1.7081 val acc 60.8757 best val_acc 62.844934 te_acc 65.272343
bulyan: at min-max at_frac 20.0 n_at 9 n_mal_sel 4 e 1230 fed_model val loss 1.6812 val acc 61.8565 best val_acc 62.844934 te_acc 65.272343
bulyan: at min-max at_frac 20.0 n_at 13 n_mal_sel 4 e 1240 fed_model val loss 1.6066 val acc 62.0289 best val_acc 62.844934 te_acc 65.272343
bulyan: at min-

## Our second AGR-agnostic attack - Min-Sum

In [20]:
def our_attack_score(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    
    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    scores = torch.sum(distances, dim=1)
    min_score = torch.min(scores)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        score = torch.sum(distance)
        
        if score <= min_score:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    mal_update = (model_re - lamda_succ * deviation)
    return mal_update

In [21]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [2000]

aggregation = 'bulyan'
at_type = 'min-sum'
chkpt = './' + aggregation
epoch_num = 0

at_fractions = [20]

for at_fraction in at_fractions:

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    print('==> Initializing global model')
    epoch_num = 0
    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = max(2, np.sum(round_users < (34*at_fraction)))
        if n_attacker > 14: n_attacker = 14

        attacker_count = 0
        for i in round_users:
            if i < (34*at_fraction) and attacker_count < n_attacker:
                attacker_count += 1
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        malicious_grads = user_grads

        if n_attacker > 0:
            if at_type == 'fang':
                agg_grads = torch.mean(malicious_grads, 0)
                deviation = torch.sign(agg_grads)
                mal_update = get_malicious_updates_fang(malicious_grads, agg_grads, deviation, n_attacker)
            elif at_type == 'min-max':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker, dev_type='sign')
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(malicious_grads, agg_grads, n_attacker, dev_type='sign')
                
            mal_updates = torch.stack([mal_update] * n_attacker)
            malicious_grads = torch.cat((mal_updates, user_grads), 0)
            
        if epoch_num == 0: print('malicious grads shape ', malicious_grads.shape)

        if aggregation == 'mean':
            agg_grads=torch.mean(malicious_grads,dim=0)
            
        elif aggregation=='krum' or aggregation=='mkrum':
            multi_k = True if aggregation == 'mkrum' else False
            if epoch_num == 0: print('multi krum is ', multi_k)
            agg_grads, krum_candidate = multi_krum(malicious_grads, n_attacker, multi_k=multi_k, verbose=True)

        elif aggregation == 'bulyan':
            agg_grads, krum_candidate = bulyan(malicious_grads, n_attacker)
            
        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 10 == 0:
            print('%s: at %s at_frac %.1f n_at %d n_mal_sel %d e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, at_fraction, n_attacker, np.sum(krum_candidate < n_attacker), epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

==> Initializing global model
malicious grads shape  torch.Size([60, 848382])
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 11 e 0 fed_model val loss 4.0064 val acc 5.5086 best val_acc 5.508649 te_acc 5.982290
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 6 e 10 fed_model val loss 3.8430 val acc 5.9720 best val_acc 8.960564 te_acc 9.552615
bulyan: at min-sum at_frac 20.0 n_at 9 n_mal_sel 9 e 20 fed_model val loss 3.8925 val acc 8.7083 best val_acc 14.698311 te_acc 15.856672
bulyan: at min-sum at_frac 20.0 n_at 8 n_mal_sel 8 e 30 fed_model val loss 3.8842 val acc 17.4629 best val_acc 20.129736 te_acc 22.062912
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 10 e 40 fed_model val loss 3.7043 val acc 27.1983 best val_acc 29.787891 te_acc 33.332475
bulyan: at min-sum at_frac 20.0 n_at 11 n_mal_sel 11 e 50 fed_model val loss 3.4163 val acc 32.4701 best val_acc 32.470140 te_acc 36.503810
bulyan: at min-sum at_frac 20.0 n_at 9 n_mal_sel 9 e 60 fed_model val loss 3.2254 val acc 30.77

bulyan: at min-sum at_frac 20.0 n_at 11 n_mal_sel 11 e 580 fed_model val loss 2.8383 val acc 35.9684 best val_acc 36.992895 te_acc 41.011120
bulyan: at min-sum at_frac 20.0 n_at 10 n_mal_sel 10 e 590 fed_model val loss 2.7175 val acc 36.5218 best val_acc 36.992895 te_acc 41.011120
bulyan: at min-sum at_frac 20.0 n_at 10 n_mal_sel 10 e 600 fed_model val loss 2.6440 val acc 37.1293 best val_acc 37.636429 te_acc 41.675247
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 14 e 610 fed_model val loss 2.6275 val acc 37.0032 best val_acc 38.104922 te_acc 42.053645
bulyan: at min-sum at_frac 20.0 n_at 11 n_mal_sel 11 e 620 fed_model val loss 2.6050 val acc 37.4382 best val_acc 38.321149 te_acc 42.365115
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 14 e 630 fed_model val loss 2.5988 val acc 38.4164 best val_acc 39.013591 te_acc 43.134782
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 14 e 640 fed_model val loss 2.5131 val acc 39.2067 best val_acc 39.706034 te_acc 43.518328
bulyan: at mi

bulyan: at min-sum at_frac 20.0 n_at 11 n_mal_sel 11 e 1170 fed_model val loss 2.6820 val acc 37.7265 best val_acc 41.363777 te_acc 44.964992
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 14 e 1180 fed_model val loss 2.6707 val acc 38.1924 best val_acc 41.363777 te_acc 44.964992
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 14 e 1190 fed_model val loss 2.6051 val acc 38.8900 best val_acc 41.363777 te_acc 44.964992
bulyan: at min-sum at_frac 20.0 n_at 7 n_mal_sel 7 e 1200 fed_model val loss 2.7121 val acc 38.9827 best val_acc 41.363777 te_acc 44.964992
bulyan: at min-sum at_frac 20.0 n_at 14 n_mal_sel 14 e 1210 fed_model val loss 2.6551 val acc 38.8488 best val_acc 41.363777 te_acc 44.964992
bulyan: at min-sum at_frac 20.0 n_at 9 n_mal_sel 9 e 1220 fed_model val loss 2.6871 val acc 38.4421 best val_acc 41.363777 te_acc 44.964992
bulyan: at min-sum at_frac 20.0 n_at 12 n_mal_sel 12 e 1230 fed_model val loss 2.7384 val acc 37.0547 best val_acc 41.363777 te_acc 44.964992
bulyan: at