# The notebook contains
### Code for _Median_ aggregation algorithm, *when gradient updates of benign clients are unknown to adversary*
### Evaluation of all of the attacks (Fang, LIE, and our SOTA AGR-tailored and AGR-agnstic) on Median

In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [3]:
from __future__ import print_function
import json
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

sys.path.insert(0,'./../utils/')
from logger import *
from eval import *
from misc import *

from femnist_normal_train import *
from femnist_util import *
from adam import Adam
from sgd import SGD
import torchvision.transforms as transforms
import torchvision.datasets as datasets

## Get the FEMNIST dataset; we use [LEAF framework](https://leaf.cmu.edu/)

In [5]:
user_tr_data = []
user_tr_labels = []

for i in range(34):
    f = '/mnt/nfs/work1/amir/vshejwalkar/leaf/data/femnist/data/train/all_data_%d_niid_0_keep_0_train_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_tr_data.append(obj['user_data'][user]['x'])
        user_tr_labels.append(obj['user_data'][user]['y'])

user_te_data = []
user_te_labels = []

for i in range(34):
    f = '/mnt/nfs/work1/amir/vshejwalkar/leaf/data/femnist/data/test/all_data_%d_niid_0_keep_0_test_9.json'%i
    with open(f, 'r') as myfile:
        data=myfile.read()
    obj = json.loads(data)
    
    for user in obj['users']:
        user_te_data.append(obj['user_data'][user]['x'])
        user_te_labels.append(obj['user_data'][user]['y'])

In [6]:
user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(len(user_tr_data)):
    
    user_tr_data_tensor=torch.from_numpy(np.array(user_tr_data[i])).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(np.array(user_tr_labels[i])).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    
#     print('user %d tr len %d'%(i,len(user_tr_data_tensor)))
print("number of FL clients are ", len(user_tr_data_tensor))

number of FL clients are  244


In [7]:
te_data = np.concatenate(user_te_data, 0)
te_labels = np.concatenate(user_te_labels)
te_len = len(te_labels)

te_data_tensor = torch.from_numpy(te_data[:(te_len//2)]).type(torch.FloatTensor)
te_label_tensor = torch.from_numpy(te_labels[:(te_len//2)]).type(torch.LongTensor)

val_data_tensor = torch.from_numpy(te_data[(te_len//2):]).type(torch.FloatTensor)
val_label_tensor = torch.from_numpy(te_labels[(te_len//2):]).type(torch.LongTensor)

## Model architecture for FEMNIST

In [8]:
class mnist_conv(nn.Module):
    def __init__(self):
        super(mnist_conv, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 62)

    def forward(self, x, noise=torch.Tensor()):
        x = x.reshape(-1, 1, 28, 28)

        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 32 * 7 * 7)  # reshape Variable
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(0)
        m.bias.data.fill_(0)

## Code for (Full-knowledge) Fang attack on Median

### Note that the Fang attacks on Trimmed-mean and Median are the same 

In [9]:
def fang_attack_trmean_partial(all_updates, n_attackers):

    model_re = torch.mean(all_updates, 0)
    model_std = torch.std(all_updates, 0)
    deviation = torch.sign(model_re)
    
    max_vector_low = model_re + 3 * model_std 
    max_vector_hig = model_re + 4 * model_std
    min_vector_low = model_re - 4 * model_std
    min_vector_hig = model_re - 3 * model_std

    max_range = torch.cat((max_vector_low[:,None], max_vector_hig[:,None]), dim=1)
    min_range = torch.cat((min_vector_low[:,None], min_vector_hig[:,None]), dim=1)

    rand = torch.from_numpy(np.random.uniform(0, 1, [len(deviation), n_attackers])).type(torch.FloatTensor).cuda()

    max_rand = torch.stack([max_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([max_range[:, 1] - max_range[:, 0]] * rand.shape[1]).T
    min_rand = torch.stack([min_range[:, 0]] * rand.shape[1]).T + rand * torch.stack([min_range[:, 1] - min_range[:, 0]] * rand.shape[1]).T

    mal_vec = (torch.stack([(deviation > 0).type(torch.FloatTensor)] * max_rand.shape[1]).T.cuda() * max_rand + torch.stack(
        [(deviation > 0).type(torch.FloatTensor)] * min_rand.shape[1]).T.cuda() * min_rand).T

    return mal_vec

## Evaluation for Full-knolwledge Fang attack on Median

In [12]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()

batch_size = 100
schedule = [5000]

aggregation = 'median'
chkpt = './' + aggregation

at_type='fang'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            proxy_n_attackers = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
        
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)
            
        if not (malicious_grads.shape[0]==60):
            print(malicious_grads.shape)
            sys.exit()
            
        agg_grads=torch.median(malicious_grads,dim=0)[0]
        
        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

median: at fang n_at 11 e 0 | val loss 3.9576 val acc 4.8548 best val_acc 4.854819 te_acc 5.202327
median: at fang n_at 12 e 20 | val loss 3.7768 val acc 5.5962 best val_acc 9.171643 te_acc 9.658155
median: at fang n_at 11 e 40 | val loss 3.7763 val acc 21.1594 best val_acc 21.159390 te_acc 23.190383
median: at fang n_at 14 e 60 | val loss 3.5524 val acc 28.9873 best val_acc 28.987335 te_acc 31.849773
median: at fang n_at 8 e 80 | val loss 3.1900 val acc 29.7776 best val_acc 31.564044 te_acc 34.789436
median: at fang n_at 12 e 100 | val loss 2.7007 val acc 35.0391 best val_acc 35.039127 te_acc 39.433175
median: at fang n_at 18 e 120 | val loss 2.4746 val acc 37.1525 best val_acc 37.976215 te_acc 42.099979
median: at fang n_at 6 e 140 | val loss 2.1404 val acc 43.2532 best val_acc 43.253192 te_acc 47.098950
median: at fang n_at 13 e 160 | val loss 1.9909 val acc 46.3705 best val_acc 46.437397 te_acc 50.105540
median: at fang n_at 13 e 180 | val loss 1.8721 val acc 49.2175 best val_acc 5

## Code for our SOTA AGR-tailored attack on Median

In [15]:
def our_attack_median(all_updates, n_attackers, dev_type='sign', threshold=5.0, threshold_diff=1e-5):
    
    model_re = torch.mean(all_updates, 0)
    
    if dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([threshold]).cuda()  # compute_lambda_our(all_updates, model_re, n_attackers)

    threshold_diff = threshold_diff
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads = torch.median(mal_updates, 0)[0]

        loss = torch.norm(agg_grads - model_re)

        if prev_loss < loss:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

## Evaluation of our SOTA AGR-tailored attack on Median

In [18]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'median'
chkpt = './' + aggregation

at_type='our-agr'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            n_attacker_ = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z)
            elif at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
            elif at_type == 'our-agr':
                mal_update = our_attack_median(attacker_grads, n_attacker_, dev_type='sign', threshold=5.0, threshold_diff=1e-5)
            elif at_type == 'min-max':
                agg_grads = torch.mean(attacker_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker_, dev_type)
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(attacker_grads, agg_grads, n_attacker_, dev_type)
                
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)
                
        if malicious_grads.shape[0] != 60 or not epoch_num: 
            print('malicious grads shape ', malicious_grads.shape)
            sys.exit()

        agg_grads=torch.median(malicious_grads,dim=0)[0]

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

malicious grads shape  torch.Size([60, 848382])
median: at our-agr n_at 9 e 0 | val loss 3.9986 val acc 4.8986 best val_acc 4.898579 te_acc 5.403110
median: at our-agr n_at 10 e 10 | val loss 3.7912 val acc 12.7471 best val_acc 12.747117 te_acc 13.789642
median: at our-agr n_at 10 e 20 | val loss 3.8561 val acc 15.3161 best val_acc 16.441001 te_acc 18.139930
median: at our-agr n_at 14 e 30 | val loss 3.8118 val acc 20.5416 best val_acc 26.266474 te_acc 29.458402
median: at our-agr n_at 7 e 40 | val loss 3.6176 val acc 28.5832 best val_acc 31.000309 te_acc 34.416186
median: at our-agr n_at 14 e 50 | val loss 3.3989 val acc 32.7327 best val_acc 32.732702 te_acc 36.655684
median: at our-agr n_at 15 e 60 | val loss 3.3995 val acc 29.2962 best val_acc 32.732702 te_acc 36.655684
median: at our-agr n_at 7 e 70 | val loss 3.0007 val acc 30.8150 best val_acc 33.100803 te_acc 36.985173
median: at our-agr n_at 6 e 80 | val loss 2.7274 val acc 32.9644 best val_acc 34.557764 te_acc 38.560544
median

median: at our-agr n_at 13 e 770 | val loss 3.8688 val acc 34.9645 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 15 e 780 | val loss 3.8979 val acc 36.5167 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 17 e 790 | val loss 3.8346 val acc 36.9672 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 6 e 800 | val loss 4.1446 val acc 37.7986 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 10 e 810 | val loss 4.2521 val acc 38.2722 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 12 e 820 | val loss 4.0767 val acc 38.0689 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 12 e 830 | val loss 4.5735 val acc 35.0546 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 22 e 840 | val loss 5.0179 val acc 32.7224 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 19 e 850 | val loss 4.2900 val acc 37.0959 best val_acc 45.279036 te_acc 47.842875
median: at our-agr n_at 14 e 860 | val

## Code for our first SOTA AGR-agnostic attack - Min-max

In [15]:
def our_attack_dist(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    max_distance = torch.max(distances)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        max_d = torch.max(distance)
        
        if max_d <= max_distance:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update

In [17]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'median'
chkpt = './' + aggregation

at_type='min-max'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            n_attacker_ = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z)
            elif at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
            elif at_type == 'our-agr':
                mal_update = our_attack_median(attacker_grads, n_attacker_, dev_type='sign', threshold=5.0, threshold_diff=1e-5)
            elif at_type == 'min-max':
                agg_grads = torch.mean(attacker_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker_, dev_type='sign')
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(attacker_grads, agg_grads, n_attacker_, dev_type='sign')
                
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)
                
        if malicious_grads.shape[0] != 60: 
            print('malicious grads shape ', malicious_grads.shape)
            sys.exit()

        agg_grads=torch.median(malicious_grads,dim=0)[0]

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

median: at min-max n_at 12 e 0 | val loss 3.9719 val acc 5.1457 best val_acc 5.145696 te_acc 5.397961
median: at min-max n_at 12 e 20 | val loss 3.8173 val acc 24.8661 best val_acc 24.866145 te_acc 28.132722
median: at min-max n_at 13 e 40 | val loss 3.3582 val acc 34.7791 best val_acc 34.797158 te_acc 39.255560
median: at min-max n_at 14 e 60 | val loss 2.6429 val acc 38.2516 best val_acc 38.251647 te_acc 42.753810
median: at min-max n_at 15 e 80 | val loss 2.4092 val acc 40.8335 best val_acc 41.008546 te_acc 45.636841
median: at min-max n_at 10 e 100 | val loss 2.2684 val acc 42.4732 best val_acc 43.917319 te_acc 48.043657
median: at min-max n_at 6 e 120 | val loss 2.2017 val acc 44.4450 best val_acc 44.445016 te_acc 48.857084
median: at min-max n_at 16 e 140 | val loss 2.1328 val acc 47.1401 best val_acc 47.140136 te_acc 51.346273
median: at min-max n_at 16 e 160 | val loss 2.0475 val acc 48.4169 best val_acc 49.886738 te_acc 53.132722
median: at min-max n_at 12 e 180 | val loss 1.9

## Code for our first SOTA AGR-agnostic attack - Min-sum

In [18]:
def our_attack_score(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    
    lamda = torch.Tensor([50.0]).float().cuda()
    # print(lamda)
    threshold_diff = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    distances = []
    for update in all_updates:
        distance = torch.norm((all_updates - update), dim=1) ** 2
        distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0)
    
    scores = torch.sum(distances, dim=1)
    min_score = torch.min(scores)
    del distances

    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        distance = torch.norm((all_updates - mal_update), dim=1) ** 2
        score = torch.sum(distance)
        
        if score <= min_score:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2

    # print(lamda_succ)
    mal_update = (model_re - lamda_succ * deviation)
    
    return mal_update
    

In [19]:
resume=0
nepochs=1500
gamma=.1
fed_lr=0.001

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()
batch_size = 100
schedule = [5000]

aggregation = 'median'
chkpt = './' + aggregation

at_type='min-sum'
at_fractions = [20]

for at_fraction in at_fractions:
    epoch_num = 0

    fed_model = mnist_conv().cuda()
    fed_model.apply(weights_init)
    optimizer_fed = Adam(fed_model.parameters(), lr=fed_lr)

    best_global_acc=0
    best_global_te_acc=0

    while epoch_num <= nepochs:
        user_grads = []

        round_users = np.random.choice(3400, 60)
        n_attacker = np.sum(round_users < (34*at_fraction))

        at_idx = []
        for i in np.sort(round_users):
            if i < (34*at_fraction):
                at_idx.append(i)
                continue

            inputs = user_tr_data_tensors[i]
            targets = user_tr_label_tensors[i]

            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            optimizer_fed.zero_grad()
            loss.backward(retain_graph=True)

            param_grad=[]
            for param in fed_model.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

            user_grads=param_grad[None,:] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]),0)    

        if n_attacker > 0:
            attacker_grads = []
            n_attacker_ = max(1, n_attacker**2//60)
            for i in at_idx:

                inputs = user_tr_data_tensors[i]
                targets = user_tr_label_tensors[i]

                inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

                outputs = fed_model(inputs)
                loss = criterion(outputs, targets)
                optimizer_fed.zero_grad()
                loss.backward(retain_graph=True)

                param_grad=[]
                for param in fed_model.parameters():
                    param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))

                attacker_grads=param_grad[None,:] if len(attacker_grads)==0 else torch.cat((attacker_grads,param_grad[None,:]),0)

            mal_updates = []
            if at_type == 'lie':
                mal_update = lie_attack(malicious_grads, z)
            elif at_type == 'fang':
                mal_updates = fang_attack_trmean_partial(attacker_grads, n_attacker)
            elif at_type == 'our-agr':
                mal_update = our_attack_median(attacker_grads, n_attacker_, dev_type='sign', threshold=5.0, threshold_diff=1e-5)
            elif at_type == 'min-max':
                agg_grads = torch.mean(attacker_grads, 0)
                mal_update = our_attack_dist(malicious_grads, agg_grads, n_attacker_, dev_type='sign')
            elif at_type == 'min-sum':
                agg_grads = torch.mean(malicious_grads, 0)
                mal_update = our_attack_score(attacker_grads, agg_grads, n_attacker_, dev_type='sign')
                
        if not len(mal_updates):
            mal_updates = torch.stack([mal_update] * n_attacker)
        malicious_grads = torch.cat((mal_updates, user_grads), 0)
                
        if malicious_grads.shape[0] != 60: 
            print('malicious grads shape ', malicious_grads.shape)
            sys.exit()

        agg_grads=torch.median(malicious_grads,dim=0)[0]

        start_idx=0

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        optimizer_fed.zero_grad()

        model_grads=[]

        for i, param in enumerate(fed_model.parameters()):
            param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx=start_idx+len(param.data.view(-1))
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        if is_best:
            best_global_te_acc = te_acc

        if epoch_num % 20 == 0 or epoch_num == nepochs-1:
            print('%s: at %s n_at %d e %d | val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        epoch_num+=1

median: at min-sum n_at 14 e 0 | val loss 3.9414 val acc 5.3336 best val_acc 5.333608 te_acc 5.781507
median: at min-sum n_at 12 e 20 | val loss 3.7699 val acc 5.1457 best val_acc 11.259267 te_acc 12.963344
median: at min-sum n_at 13 e 40 | val loss 3.7748 val acc 10.8783 best val_acc 16.394666 te_acc 18.677924
median: at min-sum n_at 10 e 60 | val loss 3.6096 val acc 19.6278 best val_acc 25.738777 te_acc 29.002780
median: at min-sum n_at 10 e 80 | val loss 3.1643 val acc 27.2704 best val_acc 29.339992 te_acc 32.951503
median: at min-sum n_at 11 e 100 | val loss 2.8909 val acc 28.5420 best val_acc 31.981054 te_acc 35.968390
median: at min-sum n_at 13 e 120 | val loss 2.6897 val acc 31.3427 best val_acc 33.489498 te_acc 36.992895
median: at min-sum n_at 11 e 140 | val loss 2.4129 val acc 35.8783 best val_acc 36.735482 te_acc 40.485997
median: at min-sum n_at 13 e 160 | val loss 2.2661 val acc 39.2916 best val_acc 39.567030 te_acc 42.962315
median: at min-sum n_at 16 e 180 | val loss 2.2