### Model Aggregation Scheme Overview

##### Summary
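Trained 6 models: [FedAvg, FAT] × [Vanilla, Trimmed Mean, Median]. Records test accuracy and adversarial accuracy for each setting. Each model was trained for only 50 rounds to save time. The cells below evaluate the Trimmed Mean (TM) and Median (Med) models, each trained on IID and non-IID client splits.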

Table 2 

In [None]:
cd to_base_directory

In [2]:
# Import General Libraries
import os
import argparse
import torch
import copy
import pickle
import random
import numpy as np
import pandas as pd

# Import FedEM based Libraries
from utils.utils import *
from utils.constants import *
from utils.args import *
from utils.util_notebooks import *
from run_experiment import *
from models import *

# Import Transfer Attack
from transfer_attacks.Personalized_NN import *
from transfer_attacks.Params import *
from transfer_attacks.Transferer import *
from transfer_attacks.Args import *
from transfer_attacks.TA_utils import *

In [3]:
setting, num_user = "FedAvg", 40

# Building the clients is slow, so reuse them if this cell already ran in this kernel.
# NOTE: the guard only checks that `aggregator` exists. If you change `setting` or
# `num_user`, restart the kernel (or `del aggregator`) so the clients are rebuilt.
if "aggregator" not in globals():
    aggregator, clients, args_ = set_args(setting, num_user)

# Load models trained with trimmed-mean (tm) and median (med) aggregation, on IID and
# non-IID client splits (non-IID induced as in FAT).
# Models are produced by run_model_training/train_model_noniid_sweep.py; change these
# paths to wherever those weights are stored.
save_path_tm_iid = "weights/cifar10/240131_niid_test/tm_iid/"
save_path_tm_niid = "weights/cifar10/240131_niid_test/tm_niid/"
save_path_med_iid = "weights/cifar10/240131_niid_test/med_iid/"
save_path_med_niid = "weights/cifar10/240131_niid_test/med_niid/"


def _load_aggregated_model(save_path):
    """Return an independent copy of the first model loaded from `save_path`.

    Uses the module-level `num_user`, `setting`, `aggregator` and `args_`.
    NOTE(review): element [0] of import_model_weights' result is presumably the
    global/first learner's model — confirm against utils.util_notebooks.
    """
    # deepcopy so the four loaded models never share parameter storage.
    return copy.deepcopy(import_model_weights(num_user, setting, save_path, aggregator, args_)[0])


model_tm_iid = _load_aggregated_model(save_path_tm_iid)
model_tm_niid = _load_aggregated_model(save_path_tm_niid)

model_med_iid = _load_aggregated_model(save_path_med_iid)
model_med_niid = _load_aggregated_model(save_path_med_niid)

==> Clients initialization..
===> Building data iterators..


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 159.67it/s]


===> Initializing clients..


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:15<00:00,  5.13it/s]


==> Test Clients initialization..
===> Building data iterators..


0it [00:00, ?it/s]


===> Initializing clients..


0it [00:00, ?it/s]


In [13]:
# Metric names available in the attack logs (for reference; not referenced in this cell).
metrics = ['orig_acc_transfers','orig_similarities','adv_acc_transfers','adv_similarities_target',
            'adv_similarities_untarget','adv_target','adv_miss']

# Every model in models_test is attacked with adversarial examples crafted on every
# model (cross_attack). Only the diagonal (a model attacked with its own examples,
# i.e. white-box / "self") is stored below.
models_test = [model_tm_iid, model_tm_niid, model_med_iid, model_med_niid]

victim_idxs = range(len(models_test))
custom_batch_size = 1000
eps = 4.5  # attack budget passed to cross_attack — NOTE(review): norm/scale defined in transfer_attacks

num_cid = 5
c_id_list = range(num_cid)  # client ids to evaluate; need not start at 0

# Rows follow the order of c_id_list (not the client id itself); columns follow models_test.
acc_store = np.zeros([len(c_id_list), len(models_test)])
adv_store = np.zeros([len(c_id_list), len(models_test)])


for row, t in enumerate(c_id_list):
    print("Running cid trial", t)
    logs_adv = generate_logs_adv(len(models_test))
    dataloader = load_client_data(clients = clients, c_id = t, mode = 'test') # or test/train
    cross_attack(logs_adv, victim_idxs, dataloader, models_test, custom_batch_size, eps)

    # Square (models x models) matrices for this client.
    adv_miss = get_metric_list("adv_miss", logs_adv, victim_idxs)
    acc = get_metric_list("orig_acc_transfers", logs_adv, victim_idxs)
    adv_target = get_metric_list("adv_target", logs_adv, victim_idxs)

    # Keep only the self-attack (diagonal) entries.
    acc_store[row] = np.diag(acc)
    adv_store[row] = np.diag(adv_miss)




Running cid trial 0
	 Adv idx: 0
	 Adv idx: 1
	 Adv idx: 2
	 Adv idx: 3
Running cid trial 1
	 Adv idx: 0
	 Adv idx: 1
	 Adv idx: 2
	 Adv idx: 3
Running cid trial 2
	 Adv idx: 0
	 Adv idx: 1
	 Adv idx: 2
	 Adv idx: 3
Running cid trial 3
	 Adv idx: 0
	 Adv idx: 1
	 Adv idx: 2
	 Adv idx: 3
Running cid trial 4
	 Adv idx: 0
	 Adv idx: 1
	 Adv idx: 2
	 Adv idx: 3


In [15]:
# Self-attack results per trial: row r is the r-th client trial, column m is
# models_test[m]; each entry is the diagonal of that trial's "adv_miss" matrix.
adv_store

array([[0.50, 0.31, 0.62, 0.41],
       [0.37, 0.27, 0.52, 0.32],
       [0.48, 0.30, 0.53, 0.37],
       [0.35, 0.27, 0.51, 0.32],
       [0.39, 0.25, 0.49, 0.31]])

In [16]:
# Collapse the client trials into one value per model (one per column of the stores):
# mean and population standard deviation (np.std default, ddof=0) of the self-attack
# adversarial metric and of the clean test accuracy.
adv_list = [np.mean(model_column) for model_column in adv_store.T]
acc_list = [np.mean(model_column) for model_column in acc_store.T]

adv_std_list = [np.std(model_column) for model_column in adv_store.T]
acc_std_list = [np.std(model_column) for model_column in acc_store.T]



In [19]:
from prettytable import PrettyTable

# One row per aggregation setting, in the same order as models_test.
labels = ['TM_iid', 'TM_niid', 'Med_iid', 'Med_niid']
x_values = acc_list
y_values = adv_list
xstd_values = acc_std_list
ystd_values = adv_std_list

# zip() would silently drop rows if the labels fell out of sync with the models.
assert len(labels) == len(x_values) == len(y_values), "labels must match models_test (order and length)"

table = PrettyTable()
table.field_names = ["Setting", "Test Acc", "Test Acc STD", "Adv Acc", "Adv Acc STD"]
# Print floats as %.4f; must be set after field_names so it applies to every column.
table.float_format = ".4"

for label, acc_mean, acc_std, adv_mean, adv_std in zip(labels, x_values, xstd_values, y_values, ystd_values):
    table.add_row([label, acc_mean, acc_std, adv_mean, adv_std])

print(table)


+----------+--------------------+----------------------+--------------------+----------------------+
| Setting  |      Test Acc      |     Test Acc STD     |      Adv Acc       |     Adv Acc STD      |
+----------+--------------------+----------------------+--------------------+----------------------+
|  TM_iid  | 0.7841004610061646 | 0.022007606756669976 | 0.4179572343826294 | 0.05990667874994324  |
| TM_niid  | 0.8064494013786316 | 0.019894864418346807 | 0.2787453681230545 | 0.022956325958320107 |
| Med_iid  | 0.8043594837188721 | 0.03343577385828096  | 0.5350706219673157 | 0.046885788682681874 |
| Med_niid | 0.8113937139511108 | 0.02365371487288273  | 0.3465024471282959 | 0.038851698369482476 |
+----------+--------------------+----------------------+--------------------+----------------------+


In [7]:
# adv_miss is left over from the last iteration of the client-trial loop: the full
# square models x models "adv_miss" matrix for that single client. Its row/column
# orientation comes from get_metric_list (presumably attacker x victim — confirm).
# NOTE(review): the saved output has execution count 7, earlier than the loop cell
# (13), so it likely comes from a previous run and may be stale. Re-run after the loop.
adv_miss

array([[0.42, 0.57, 0.57, 0.57],
       [0.62, 0.28, 0.58, 0.44],
       [0.68, 0.66, 0.53, 0.64],
       [0.71, 0.66, 0.68, 0.35]])