In [1]:
import os
import sys
sys.path.append("../../")

import pandas as pandas
import numpy as np

import torch
import matplotlib.pyplot as plt

from copy import deepcopy

from utils import load_all_models, load_all_client_loader, get_client_result
from src.simulator.utils import get_client_dataset, get_key_by_value
from src.data.DataList import dataset_dict

import torch.nn as nn

import seaborn as sns

import warnings
# Suppress every warning (default category is Warning) for the rest of the kernel session.
# NOTE(review): this also hides deprecation/user warnings from torch, matplotlib, seaborn, etc.
# Consider narrowing to the specific category/module that is actually noisy.
warnings.filterwarnings("ignore")

In [2]:
# Best checkpoint per model type: {model_type: [best-model wandb run id, best-model round]}.
# Center and Local carry -1 as a placeholder: no epoch needs to be specified for them.
# Keyword order is preserved, so iteration order (and plot order) stays Center, Local, FedAvg, FedProx, MOON.
model_type_dict = dict(
    Center=["dandy-shadow-1", -1],
    Local=["skilled-bee-1", -1],
    FedAvg=["fresh-feather-2", 53],
    FedProx=["glorious-capybara-5", 88],
    MOON=["iconic-rain-9", 59],
)

# Checkpoint root handed to load_all_models.
BASE = "/NFS/Users/moonsh/AdaptFL/ckpt/"
device = torch.device("cuda:3")
batch_size = 16
worker = 0

In [None]:
# Load every checkpoint listed in model_type_dict. The returned config is what the
# client loaders below are built from.
model_dict, config = load_all_models(model_type_dict, BASE, device, batch_size, worker)

# Train / validation / test loader collections (presumably one entry per client; TODO confirm).
train_client_list, val_client_list, test_client_list = load_all_client_loader(
    config, _mode='all'
)

# Result path passed to get_client_result (presumably where per-client outputs are written).
savepath = '/NFS/Users/moonsh/AdaptFL/result/'

In [4]:
# Evaluate every loaded model on each client in turn.
# result_dict maps client index -> whatever get_client_result returns for that client.
result_dict = dict()
for client_idx in range(len(train_client_list)):
    mae_dict = get_client_result(
        client_idx, model_dict,
        train_client_list, val_client_list, test_client_list,
        device, savepath, model_type_dict,
    )
    result_dict.update({client_idx: mae_dict})

In [None]:
# Per-client comparison of the trained models: federated methods are drawn as bars,
# while "Center" and "Local" are drawn as dashed horizontal reference lines.
N_COLS = 5  # subplots per row
REFERENCE_COLORS = {"Center": "red", "Local": "blue"}  # model types shown as lines, not bars

n_clients = len(train_client_list)
# Size the grid from the client count (ceil division) instead of a fixed 2x5, so more
# than 10 clients no longer raises IndexError. For 6-10 clients this is the original 2x5 / (25, 10).
n_rows = max(1, int(np.ceil(n_clients / N_COLS)))
fig, axs = plt.subplots(n_rows, N_COLS, figsize=(5 * N_COLS, 5 * n_rows), squeeze=False)

model_type = model_type_dict.keys()

for client_idx in range(n_clients):
    ax = axs[client_idx // N_COLS][client_idx % N_COLS]
    client_result = result_dict[client_idx]

    bar_list, keys, ref_values = [], [], []
    for _type in model_type:
        if _type in REFERENCE_COLORS:
            ax.axhline(client_result[_type], label=_type, color=REFERENCE_COLORS[_type],
                       linestyle='--', linewidth=2)
            ref_values.append(client_result[_type])
        else:
            bar_list.append(client_result[_type])
            keys.append(_type)

    colors = sns.color_palette('Set3', len(bar_list))
    ax.bar(keys, bar_list, color=colors)

    # Fit the y-range to the bars AND the reference lines. Scaling to the bars alone could
    # push the Center/Local lines above the visible area. Skip if nothing was plotted
    # (max() of an empty array raises).
    plotted_values = bar_list + ref_values
    if plotted_values:
        ax.set_ylim(0, np.array(plotted_values).max() * 1.5)
    # Style the existing categorical tick labels rather than replacing them via
    # set_xticklabels (which expects fixed ticks to have been set first).
    ax.tick_params(axis='x', labelrotation=45, labelsize=15)
    ax.set_title(f"Client {client_idx} ({get_key_by_value(dataset_dict, client_idx)})", fontsize=20)
    ax.legend(fontsize=15, loc='upper right')

# Hide grid cells that have no client assigned.
for unused_ax in axs.flat[n_clients:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()