### Model Aggregation Scheme Overview

Jan 14 2024
TJ Kim

##### Summary
Trained 6 models ([FedAvg, FAT] × [Vanilla, Trimmed Mean, Median]) and recorded the test accuracy and adversarial accuracy for each setting. To save time, each model was trained for only 50 rounds.

The goal is to check whether robust aggregation schemes harm performance, since we want to argue that robust aggregation schemes are not an effective defense against our proposed attack.

In [1]:
cd /home/ubuntu/fedatk_unl_tj/

/home/ubuntu/fedatk_unl_tj


In [2]:
# Import General Libraries
import os
import argparse
import torch
import copy
import pickle
import random
import numpy as np
import pandas as pd

# Import FedEM based Libraries
from utils.utils import *
from utils.constants import *
from utils.args import *
from utils.util_notebooks import *
from run_experiment import *
from models import *

# Import Transfer Attack
from transfer_attacks.Personalized_NN import *
from transfer_attacks.Params import *
from transfer_attacks.Transferer import *
from transfer_attacks.Args import *
from transfer_attacks.TA_utils import *

In [3]:
setting, num_user = "FedAvg", 40

# set_args builds the aggregator and initializes every client, which takes
# several seconds, so reuse them if this kernel already created them.
# Only a missing name should trigger the rebuild; any other error must surface.
try:
    aggregator
except NameError:
    aggregator, clients, args_ = set_args(setting, num_user)

# Checkpoint directories: {FedAvg, FAT} x {vanilla, trimmed mean (tm), median (med)}.
save_path_FAT = 'weights/cifar10/240115_robust_tests/FAT/'
save_path_FedAvg = 'weights/cifar10/240111_robust_tests/fedavg/'

save_path_FAT_tm = 'weights/cifar10/240115_robust_tests/FAT_tm/'
save_path_FedAvg_tm = 'weights/cifar10/240111_robust_tests/fedavg_tm/'

save_path_FAT_med = 'weights/cifar10/240115_robust_tests/FAT_md/'
save_path_FedAvg_med = 'weights/cifar10/240111_robust_tests/fedavg_md/'


def _load_first_model(save_path):
    """Load the checkpoint in ``save_path`` and return a deep copy of its first model.

    Every load goes through the same shared ``aggregator``/``args_``. The deep copy
    keeps the returned model independent of later loads (presumably
    import_model_weights writes into the aggregator's learners; TODO confirm).

    Note: every checkpoint, including the FAT ones, is loaded with ``setting``
    ("FedAvg"). This presumably only selects the model/aggregator layout; verify.
    """
    loaded = import_model_weights(num_user, setting, save_path, aggregator, args_)
    return copy.deepcopy(loaded[0])


# Same load order as before, in case import_model_weights has side effects.
model_FAT = _load_first_model(save_path_FAT)
model_Fedavg = _load_first_model(save_path_FedAvg)

model_FAT_tm = _load_first_model(save_path_FAT_tm)
model_Fedavg_tm = _load_first_model(save_path_FedAvg_tm)

model_FAT_med = _load_first_model(save_path_FAT_med)
model_Fedavg_med = _load_first_model(save_path_FedAvg_med)

==> Clients initialization..
===> Building data iterators..


100%|██████████| 80/80 [00:00<00:00, 159.18it/s]


===> Initializing clients..


100%|██████████| 80/80 [00:14<00:00,  5.36it/s]


==> Test Clients initialization..
===> Building data iterators..


0it [00:00, ?it/s]


===> Initializing clients..


0it [00:00, ?it/s]


In [4]:
# Transfer-attack every model against every other model and record the stats.
# Order here defines the row/column indices of logs_adv (and the table labels later).
models_test = [
    model_Fedavg, model_FAT,
    model_Fedavg_tm, model_FAT_tm,
    model_Fedavg_med, model_FAT_med,
]

# Measurements feed both the targeted and the untargeted analysis.
victim_idxs = range(len(models_test))
logs_adv = generate_logs_adv(len(models_test))

# Attack hyper-parameters passed straight through to cross_attack.
custom_batch_size = 1000  # batch size used by cross_attack
eps = 4.5                 # presumably the attack's perturbation budget; TODO confirm norm/scale

# Client 0's data; mode may also be 'test' or 'train'.
dataloader = load_client_data(clients=clients, c_id=0, mode='all')
cross_attack(logs_adv, victim_idxs, dataloader, models_test, custom_batch_size, eps)

	 Adv idx: 0
	 Adv idx: 1
	 Adv idx: 2
	 Adv idx: 3
	 Adv idx: 4
	 Adv idx: 5


In [5]:
# Metric names accepted by get_metric_list (presumably the keys recorded in logs_adv).
metrics = [
    'orig_acc_transfers', 'orig_similarities', 'adv_acc_transfers',
    'adv_similarities_target', 'adv_similarities_untarget',
    'adv_target', 'adv_miss',
]

adv_miss = get_metric_list("adv_miss", logs_adv, victim_idxs)
acc = get_metric_list("orig_acc_transfers", logs_adv, victim_idxs)
adv_target = get_metric_list("adv_target", logs_adv, victim_idxs)

# Keep only the diagonal entries (i, i), i.e. robustness against white-box
# attacks on the model itself. This cell only collects the values; the
# results table reports them.
adv_list, acc_list = [], []
for i in range(adv_miss.shape[0]):
    adv_list.append(adv_miss[i, i])  # self-attack adv_miss, reported as "Adv Acc"
    acc_list.append(acc[i, i])       # orig_acc_transfers, reported as "Test Acc"



In [6]:
from prettytable import PrettyTable

# Row labels; must follow the same order as models_test in the attack cell.
labels = ['FedAvg', 'FAT', 'FedAvg TM', 'FAT TM', 'FedAvg Med', 'FAT Med']
x_values = acc_list
y_values = adv_list

# zip() would silently drop rows if the labels and metric lists ever drift apart.
if not (len(labels) == len(x_values) == len(y_values)):
    raise ValueError(
        f"label/metric length mismatch: {len(labels)} labels, "
        f"{len(x_values)} test accs, {len(y_values)} adv accs"
    )

table = PrettyTable()
table.field_names = ["Setting", "Test Acc", "Adv Acc"]

for label, x, y in zip(labels, x_values, y_values):
    # The raw values carry float32 upcast noise (e.g. 0.8560000658035278) but are
    # multiples of 0.001, consistent with custom_batch_size = 1000 samples.
    # Three decimals therefore lose no real precision.
    table.add_row([label, f"{float(x):.3f}", f"{float(y):.3f}"])

print(table)


+------------+--------------------+----------------------+
|  Setting   |      Test Acc      |       Adv Acc        |
+------------+--------------------+----------------------+
|   FedAvg   | 0.8560000658035278 | 0.01100000087171793  |
|    FAT     | 0.8080000281333923 | 0.36100003123283386  |
| FedAvg TM  | 0.8210000395774841 | 0.013000000268220901 |
|   FAT TM   | 0.7910000085830688 | 0.39400002360343933  |
| FedAvg Med | 0.815000057220459  | 0.09300000220537186  |
|  FAT Med   | 0.7950000166893005 |  0.5430000424385071  |
+------------+--------------------+----------------------+
