# The notebook contains
### Code for _Trimmed-mean_ aggregation algorithm, *when gradient updates of benign clients are unknown to adversary*
### Evaluation of all of the attacks (Fang, LIE, and our SOTA AGR-tailored and AGR-agnstic) on Trimmed-mean

In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [2]:
from __future__ import print_function
import json
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

sys.path.insert(0,'./../utils/')
from logger import *
from eval import *
from misc import *

from femnist_normal_train import *
from femnist_util import *
from adam import Adam
from sgd import SGD
import torchvision.transforms as transforms
import torchvision.datasets as datasets

## Get the FEMNIST dataset; we use [LEAF framework](https://leaf.cmu.edu/)

In [3]:
user_tr_data = []
user_tr_labels = []

for i in range(34):
    f = '/mnt/nfs/work1/amir/vshejwalkar/leaf/data/femnist/data/train/all_data_%d_niid_0_keep_0_train_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_tr_data.append(obj['user_data'][user]['x'])
        user_tr_labels.append(obj['user_data'][user]['y'])

user_te_data = []
user_te_labels = []

for i in range(34):
    f = '/mnt/nfs/work1/amir/vshejwalkar/leaf/data/femnist/data/test/all_data_%d_niid_0_keep_0_test_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_te_data.append(obj['user_data'][user]['x'])
        user_te_labels.append(obj['user_data'][user]['y'])

In [4]:
user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(len(user_tr_data)):
    
    user_tr_data_tensor=torch.from_numpy(np.array(user_tr_data[i])).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(np.array(user_tr_labels[i])).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    
#     print('user %d tr len %d'%(i,len(user_tr_data_tensor)))
print("number of FL clients are ", len(user_tr_data_tensor))

number of FL clients are  244


In [5]:
te_data = np.concatenate(user_te_data, 0)
te_labels = np.concatenate(user_te_labels)
te_len = len(te_labels)

te_data_tensor = torch.from_numpy(te_data[:(te_len//2)]).type(torch.FloatTensor)
te_label_tensor = torch.from_numpy(te_labels[:(te_len//2)]).type(torch.LongTensor)

val_data_tensor = torch.from_numpy(te_data[(te_len//2):]).type(torch.FloatTensor)
val_label_tensor = torch.from_numpy(te_labels[(te_len//2):]).type(torch.LongTensor)

## Model architecture for FEMNIST

In [6]:
class mnist_conv(nn.Module):
    def __init__(self):
        super(mnist_conv, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 62)

    def forward(self, x, noise=torch.Tensor()):
        x = x.reshape(-1, 1, 28, 28)

        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 32 * 7 * 7)  # reshape Variable
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(0)
        m.bias.data.fill_(0)

In [9]:
def tr_mean(all_updates, n_attackers):
    sorted_updates = torch.sort(all_updates, 0)[0]
    out = torch.mean(sorted_updates[n_attackers:-n_attackers], 0) if n_attackers else torch.mean(sorted_updates,0)
    return out

## Code for (Full-knowledge) Fang attack on Trimmed-mean

### Note that the Fang attacks on Trimmed-mean and Median are the same 

In [7]:
def fang_attack_trmean_partial(all_updates, n_attackers):

    model_re = torch.mean(all_updates, 0)
    model_std = torch.std(all_updates, 0)
    deviation = torch.sign(model_re)
    
    max_vector_low = model_re + 3 * model_std 
    max_vector_hig = model_re + 4 * model_std
    min_vector_low = model_re - 4 * model_std
    min_vector_hig = model_re - 3 * model_std

    max_range = torch.cat((max_vector_low[:,None], max_vector_hig[:,None]), dim=1)
    min_range = torch.cat((min_vector_low[:,None], min_vector_hig[:,None]), dim=1)

    rand = torch.from_numpy(np.random.uniform(0, 1, [len(deviation), n_attackers])).type(torch.FloatTensor).cuda()

    max_rand = torch.stack([max_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([max_range[:, 1] - max_range[:, 0]] * rand.shape[1]).T
    min_rand = torch.stack([min_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([min_range[:, 1] - min_range[:, 0]] * rand.shape[1]).T

    mal_vec = (torch.stack([(deviation > 0).type(torch.FloatTensor)] * max_rand.shape[1]).T.cuda() * max_rand + torch.stack(
        [(deviation > 0).type(torch.FloatTensor)] * min_rand.shape[1]).T.cuda() * min_rand).T

    return mal_vec

## Evaluation for Full-knolwledge Fang attack on Trimmed-mean

In [10]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()

batch_size = 100
schedule = [5000]

aggregation = 'trmean'
chkpt = './' + aggregation

at_type='fang'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            proxy_n_attackers = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
        
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)
            
        if not (malicious_grads.shape[0]==60):
            print(malicious_grads.shape)
            sys.exit()
            
        agg_grads=tr_mean(malicious_grads, n_attacker)
        
        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729138878/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)


trmean: at fang n_at 14 e 0 | val loss 3.9637 val acc 4.7261 best val_acc 4.726112 te_acc 5.251236
trmean: at fang n_at 11 e 20 | val loss 3.7232 val acc 9.6633 best val_acc 16.675247 te_acc 18.317545
trmean: at fang n_at 12 e 40 | val loss 3.6013 val acc 25.7619 best val_acc 26.714374 te_acc 29.553645
trmean: at fang n_at 11 e 60 | val loss 3.2126 val acc 32.2204 best val_acc 32.601421 te_acc 36.732908
trmean: at fang n_at 12 e 80 | val loss 2.7087 val acc 36.2464 best val_acc 36.442030 te_acc 41.047158
trmean: at fang n_at 17 e 100 | val loss 2.2837 val acc 41.3509 best val_acc 41.350906 te_acc 45.696046
trmean: at fang n_at 19 e 120 | val loss 2.0236 val acc 45.8145 best val_acc 46.373044 te_acc 49.822385
trmean: at fang n_at 16 e 140 | val loss 1.8586 val acc 50.5354 best val_acc 50.535420 te_acc 53.554881
trmean: at fang n_at 14 e 160 | val loss 1.7640 val acc 51.5059 best val_acc 53.058072 te_acc 55.683690
trmean: at fang n_at 13 e 180 | val loss 1.6202 val acc 55.9051 best val_a

## Code for our SOTA AGR-tailored attack on Trimmed-mean

In [11]:
def our_attack_trmean(all_updates, n_attackers, dev_type='sign', threshold=5.0, threshold_diff=1e-5):
    
    model_re = torch.mean(all_updates, 0)
    
    if dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([threshold]).cuda()  # compute_lambda_our(all_updates, model_re, n_attackers)

    threshold_diff = threshold_diff
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads = torch.median(mal_updates, 0)[0]

        loss = torch.norm(agg_grads - model_re)

        if prev_loss < loss:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

## Evaluation of our SOTA AGR-tailored attack on Trimmed-mean

In [13]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'trmean'
chkpt = './' + aggregation

at_type='our-agr'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            n_attacker_ = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z)
            elif at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
            elif at_type == 'our-agr':
                mal_update = our_attack_trmean(attacker_grads, n_attacker_, dev_type='sign', threshold=5.0, threshold_diff=1e-5)
            elif at_type == 'min-max':
                agg_grads = torch.mean(attacker_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker_, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(attacker_grads, agg_grads, n_attacker_, dev_type)
                
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)

        if malicious_grads.shape[0] != 60: 
            print('malicious grads shape ', malicious_grads.shape)
            sys.exit()

        agg_grads=tr_mean(malicious_grads, n_attacker)

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

trmean: at our-agr n_at 14 e 0 | val loss 3.9537 val acc 4.7261 best val_acc 4.726112 te_acc 5.243513
trmean: at our-agr n_at 11 e 20 | val loss 3.8830 val acc 7.1252 best val_acc 11.405993 te_acc 12.157640
trmean: at our-agr n_at 8 e 40 | val loss 3.7386 val acc 26.4029 best val_acc 27.169996 te_acc 30.655375
trmean: at our-agr n_at 12 e 60 | val loss 3.4513 val acc 32.2925 best val_acc 32.838241 te_acc 36.918245
trmean: at our-agr n_at 14 e 80 | val loss 3.2786 val acc 35.0160 best val_acc 35.015960 te_acc 39.482084
trmean: at our-agr n_at 10 e 100 | val loss 2.7051 val acc 37.8527 best val_acc 37.852657 te_acc 42.540157
trmean: at our-agr n_at 14 e 120 | val loss 2.5775 val acc 38.3855 best val_acc 40.174526 te_acc 44.849156
trmean: at our-agr n_at 14 e 140 | val loss 2.2135 val acc 43.3227 best val_acc 43.322694 te_acc 47.575165
trmean: at our-agr n_at 7 e 160 | val loss 2.0701 val acc 46.0307 best val_acc 46.409082 te_acc 50.136429
trmean: at our-agr n_at 10 e 180 | val loss 2.354

## Code for our first SOTA AGR-agnostic attack - Min-max

In [14]:
def our_attack_dist(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    max_distance = torch.max(distances)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        max_d = torch.max(distance)
        
        if max_d <= max_distance:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

In [16]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'trmean'
chkpt = './' + aggregation

at_type='min-max'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            n_attacker_ = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z)
            elif at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
            elif at_type == 'our-agr':
                mal_update = our_attack_trmean(attacker_grads, n_attacker_, dev_type='sign', threshold=5.0, threshold_diff=1e-5)
            elif at_type == 'min-max':
                agg_grads = torch.mean(attacker_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker_, dev_type='sign')
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(attacker_grads, agg_grads, n_attacker_, dev_type='sign')
                
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)
                
        if malicious_grads.shape[0] != 60: 
            print('malicious grads shape ', malicious_grads.shape)
            sys.exit()

        agg_grads=tr_mean(malicious_grads, n_attacker)

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

trmean: at min-max n_at 11 e 0 | val loss 3.9353 val acc 4.7416 best val_acc 4.741557 te_acc 5.230643
trmean: at min-max n_at 10 e 20 | val loss 3.8028 val acc 21.0152 best val_acc 21.071870 te_acc 23.434926
trmean: at min-max n_at 12 e 40 | val loss 3.5589 val acc 32.0119 best val_acc 32.444399 te_acc 36.859040
trmean: at min-max n_at 11 e 60 | val loss 3.0166 val acc 35.3969 best val_acc 35.860276 te_acc 40.496293
trmean: at min-max n_at 12 e 80 | val loss 2.4979 val acc 39.1166 best val_acc 39.116557 te_acc 43.541495
trmean: at min-max n_at 13 e 100 | val loss 2.2924 val acc 40.4500 best val_acc 41.613468 te_acc 46.275227
trmean: at min-max n_at 8 e 120 | val loss 2.3111 val acc 42.2261 best val_acc 44.728171 te_acc 48.591948
trmean: at min-max n_at 13 e 140 | val loss 2.2046 val acc 44.1078 best val_acc 46.331857 te_acc 49.791495
trmean: at min-max n_at 16 e 160 | val loss 2.0067 val acc 49.4826 best val_acc 49.482599 te_acc 52.231775
trmean: at min-max n_at 12 e 180 | val loss 2.0

## Code for our first SOTA AGR-agnostic attack - Min-sum

In [17]:
def our_attack_score(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    
    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    scores = torch.sum(distances, dim=1)
    min_score = torch.min(scores)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        score = torch.sum(distance)
        
        if score <= min_score:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    # print(lamda_succ)
    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update
    

In [19]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'trmean'
chkpt = './' + aggregation

at_type='min-max'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            n_attacker_ = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z)
            elif at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
            elif at_type == 'our-agr':
                mal_update = our_attack_trmean(attacker_grads, n_attacker_, dev_type='sign', threshold=5.0, threshold_diff=1e-5)
            elif at_type == 'min-max':
                agg_grads = torch.mean(attacker_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker_, dev_type='sign')
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(attacker_grads, agg_grads, n_attacker_, dev_type='sign')
                
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)
                
        if malicious_grads.shape[0] != 60: 
            print('malicious grads shape ', malicious_grads.shape)
            sys.exit()

        agg_grads=tr_mean(malicious_grads, n_attacker)

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

trmean: at min-max n_at 9 e 0 | val loss 3.9717 val acc 5.2898 best val_acc 5.289848 te_acc 5.974568
trmean: at min-max n_at 11 e 20 | val loss 3.8340 val acc 14.8425 best val_acc 16.080622 te_acc 18.026668
trmean: at min-max n_at 18 e 40 | val loss 3.6740 val acc 28.0272 best val_acc 28.740218 te_acc 32.390342
trmean: at min-max n_at 12 e 60 | val loss 3.4031 val acc 30.4597 best val_acc 30.660523 te_acc 34.367278
trmean: at min-max n_at 8 e 80 | val loss 3.1673 val acc 33.8833 best val_acc 34.251442 te_acc 38.231054
trmean: at min-max n_at 11 e 100 | val loss 2.6509 val acc 38.8617 best val_acc 38.861717 te_acc 43.222302
trmean: at min-max n_at 13 e 120 | val loss 2.6429 val acc 40.3187 best val_acc 40.439662 te_acc 44.882619
trmean: at min-max n_at 11 e 140 | val loss 2.2946 val acc 42.3497 best val_acc 42.349671 te_acc 46.970243
trmean: at min-max n_at 17 e 160 | val loss 2.1820 val acc 43.2609 best val_acc 43.608423 te_acc 48.316516
trmean: at min-max n_at 8 e 180 | val loss 2.126