In [None]:
import os
import sys

from utils import get_round_loss_score, train_one_epoch_output

sys.path.append("../")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

import seaborn as sns

import torch
import torch.nn as nn
import random

from src.Metamodel.data.dataloader import create_train_val_loader
from src.Metamodel.models.depthwiseNet import DepthNet

import torch.backends.cudnn as cudnn

from sklearn.metrics import confusion_matrix, auc, roc_auc_score, f1_score, classification_report, cohen_kappa_score
from torchmetrics.classification import MulticlassAUROC

import matplotlib.pyplot as plt
import seaborn as sns

# Make runs reproducible: seed every RNG the notebook uses exactly once and force
# deterministic cuDNN kernels (benchmark mode picks kernels non-deterministically).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # also covers the current device (torch.cuda.manual_seed)
cudnn.benchmark = False
cudnn.deterministic = True

In [None]:
# Dataset root and compute device shared by the cells below.
DATAPATH = "../data/new"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Averaging mode passed to F1 / AUROC ("macro" or "weighted").
weightType = "macro"

# Template model instance; the personalized-training cell builds a fresh one per client.
model = DepthNet(
    lengths=30,
    patch_size=30,
    in_chans=2,
    embed_dim=256,
    norm_layer=None,
    output_dim=3,
).to(device)

In [None]:
# Training / evaluation schedule.
epochs, batch_size = 10, 512  # personalized training: epochs per client, samples per batch
want_round = 5                # rounds requested from get_round_loss_score and shown in plots
sameround = True              # not referenced anywhere in this notebook

In [None]:
# Per-client histories of the personalized (local-only) baseline; one entry per client.
local_train_results, local_test_results = [], []

local_test_auc, local_test_f1, local_test_acc = [], [], []

local_train_auc, local_train_f1, local_train_acc = [], [], []

In [None]:
def _run_epoch(model, loader, criterion, auroc, optimizer=None):
    """Run one pass of `model` over `loader` and return that pass's metrics.

    When `optimizer` is given, the model is put in train mode and updated after
    every batch. Otherwise it runs in eval mode with gradients disabled.

    Returns a dict with these keys:
    - "loss": mean of the per-batch losses.
    - "acc": accuracy over the loader's dataset.
    - "f1": F1 score, averaged with the global `weightType`.
    - "auc": torchmetrics AUROC on softmax probabilities, as a 0-d numpy array.
    - "confusion_matrix": sklearn convention, rows = true class, columns = predicted class.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    correct = 0
    preds, reals, probas = [], [], []

    with torch.set_grad_enabled(training):
        for data in loader:
            x_data, stage = data[0].to(device), data[1].to(device)

            logits, _ = model(x_data)
            loss = criterion(logits, stage)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred = torch.argmax(logits, dim=1)
            # CrossEntropyLoss models mutually exclusive classes, so class
            # probabilities come from softmax (sigmoid would not sum to 1).
            proba = torch.softmax(logits, dim=1)

            correct += torch.sum(pred == stage).item()
            preds.extend(pred.detach().cpu().numpy())
            reals.extend(stage.detach().cpu().numpy())
            probas.extend(proba.detach().cpu().numpy())
            total_loss += loss.item()

    # Stack into single arrays first; torch.tensor(list_of_ndarrays) is very slow.
    auc = auroc(torch.from_numpy(np.stack(probas)), torch.from_numpy(np.asarray(reals)))
    # forward() returns the value for this call but also accumulates every call
    # into the metric's state; reset so that state does not grow across epochs.
    auroc.reset()

    return {
        "loss": total_loss / len(loader),
        "acc": correct / len(loader.dataset),
        "f1": f1_score(reals, preds, average=weightType),
        "auc": auc.cpu().numpy(),
        # sklearn signature is confusion_matrix(y_true, y_pred).
        "confusion_matrix": confusion_matrix(reals, preds),
    }


RESULT_KEYS = ("loss", "acc", "f1", "auc", "confusion_matrix")

for client_idx in range(1, 11):
    torch.manual_seed(0)
    # validation_set is returned by the factory but not used below: evaluation
    # runs on a held-out 20% split of the client's training data instead.
    train_set, validation_set = create_train_val_loader(DATAPATH, batch_size, length=30,
                                                        meta_train_client_idx_lst=[client_idx], FLtrain=True)

    # 80/20 split with a fixed generator so every run sees the same partition.
    n_train = int(len(train_set) * 0.8)
    train, test = torch.utils.data.random_split(train_set,
                                                [n_train, len(train_set) - n_train],
                                                generator=torch.Generator().manual_seed(0))

    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

    # Fresh model per client. `model` stays a module-level name because later
    # cells pass it to get_round_loss_score.
    model = DepthNet(lengths=30, patch_size=30, in_chans=2, embed_dim=256, norm_layer=None, output_dim=3).to(device)

    lr = 0.0001

    optimizer_outer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001)
    # NOTE: with epochs = 10, step_size=40 is never reached, so the LR is never decayed.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer_outer, step_size=40, gamma=0.5, last_epoch=-1, verbose=False)
    criterion = nn.CrossEntropyLoss().to(device)
    auroc = MulticlassAUROC(num_classes=3, average=weightType).to(device)

    train_result_dic = {key: [] for key in RESULT_KEYS}
    test_result_dic = {key: [] for key in RESULT_KEYS}

    for epoch in range(1, epochs + 1):
        for key, value in _run_epoch(model, train_loader, criterion, auroc, optimizer_outer).items():
            train_result_dic[key].append(value)

        scheduler.step()

        for key, value in _run_epoch(model, test_loader, criterion, auroc).items():
            test_result_dic[key].append(value)

    # Per-epoch curves; train and test metrics are kept in separate lists.
    local_train_acc.append(list(train_result_dic["acc"]))
    local_train_auc.append(list(train_result_dic["auc"]))
    local_train_f1.append(list(train_result_dic["f1"]))

    local_test_acc.append(list(test_result_dic["acc"]))
    local_test_auc.append(list(test_result_dic["auc"]))
    local_test_f1.append(list(test_result_dic["f1"]))

    local_train_results.append(train_result_dic)
    local_test_results.append(test_result_dic)

In [None]:
# Index 0 -> per-client train histories, index 1 -> per-client test histories.
local_result = list((local_train_results, local_test_results))

In [None]:
# Per-client metric curves for the three federated variants, one list per
# (variant, split, metric); filled by the evaluation cells below.
random_train_auc, random_train_f1, random_train_acc = [], [], []
random_test_auc, random_test_f1, random_test_acc = [], [], []

meta_train_auc, meta_train_f1, meta_train_acc = [], [], []
meta_test_auc, meta_test_f1, meta_test_acc = [], [], []

normal_train_auc, normal_train_f1, normal_train_acc = [], [], []
normal_test_auc, normal_test_f1, normal_test_acc = [], [], []

In [None]:
clients_results_meta = []

# Checkpoint folders meta_try4 and meta_try8 gave the best results; meta_try8 is evaluated here.
MODE = "meta_try8"
MODELPATH = f"../src/FedMeta/model_file_cache/{MODE}"

for client in range(1, 11):
    train_result, test_result = get_round_loss_score(
        DATAPATH, MODELPATH, model, device, client,
        weightedType=weightType, want_round=want_round,
    )

    # Route each metric curve into its (train, test) accumulator pair.
    for metric, train_store, test_store in (
        ("auc", meta_train_auc, meta_test_auc),
        ("f1", meta_train_f1, meta_test_f1),
        ("acc", meta_train_acc, meta_test_acc),
    ):
        train_store.append(train_result[metric])
        test_store.append(test_result[metric])

    clients_results_meta.append([train_result, test_result])

In [None]:
clients_results_normal = []

# Evaluate the checkpoints stored under the "normal" cache folder.
MODE = "normal"
MODELPATH = f"../src/FedMeta/model_file_cache/{MODE}"

for client in range(1, 11):
    train_result, test_result = get_round_loss_score(
        DATAPATH, MODELPATH, model, device, client,
        weightedType=weightType, want_round=want_round,
    )

    # Route each metric curve into its (train, test) accumulator pair.
    for metric, train_store, test_store in (
        ("auc", normal_train_auc, normal_test_auc),
        ("f1", normal_train_f1, normal_test_f1),
        ("acc", normal_train_acc, normal_test_acc),
    ):
        train_store.append(train_result[metric])
        test_store.append(test_result[metric])

    clients_results_normal.append([train_result, test_result])

In [None]:
clients_results_random = []

# Evaluate the checkpoints stored under the "random" cache folder.
MODE = "random"
MODELPATH = f"../src/FedMeta/model_file_cache/{MODE}"

for client in range(1, 11):
    train_result, test_result = get_round_loss_score(
        DATAPATH, MODELPATH, model, device, client,
        weightedType=weightType, want_round=want_round,
    )

    # Route each metric curve into its (train, test) accumulator pair.
    for metric, train_store, test_store in (
        ("auc", random_train_auc, random_test_auc),
        ("f1", random_train_f1, random_test_f1),
        ("acc", random_train_acc, random_test_acc),
    ):
        train_store.append(train_result[metric])
        test_store.append(test_result[metric])

    clients_results_random.append([train_result, test_result])

In [None]:
# One figure per client comparing the personalized baseline with the three
# federated variants. Row 0 = train split, row 1 = test split; columns are
# loss, AUC, F1 and accuracy over the first `want_round` epochs/rounds.
# Also records, per method, the test epoch/round index with the highest F1.
best_rounds_dic = {"random": [], "normal": [], "meta": [], "personalized": [],}

metric_panels = [("loss", "Loss"), ("auc", "AUC"), ("f1", "F1 score"), ("acc", "Accuracy")]
x_ticks, x_labels = np.arange(0, want_round, 1), np.arange(1, want_round + 1, 1)

for client_idx in range(10):
    # (legend label, [train_dict, test_dict]) in plotting order.
    methods = [
        ("Personalized", [local_result[0][client_idx], local_result[1][client_idx]]),
        ("Random", clients_results_random[client_idx]),
        ("Normal", clients_results_normal[client_idx]),
        ("Meta", clients_results_meta[client_idx]),
    ]

    fig, ax = plt.subplots(2, 4, figsize=(15, 5))

    for row, split_name in enumerate(("Train", "Test")):
        for col, (key, ylabel) in enumerate(metric_panels):
            panel = ax[row, col]
            for label, split_results in methods:
                panel.plot(split_results[row][key][:want_round], label=label)
            panel.set_ylabel(ylabel)
            panel.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
            panel.set_xticks(x_ticks, x_labels)
        # A per-row fig.suptitle would be overwritten by the client title set
        # below, so label each row on its first panel instead.
        ax[row, 0].set_title(f"({split_name})", loc="left")

    # Best test epoch/round per method, chosen by F1 (np.argmax returns the first maximum).
    best_rounds_dic["personalized"].append(int(np.argmax(local_result[1][client_idx]["f1"])))
    best_rounds_dic["random"].append(int(np.argmax(clients_results_random[client_idx][1]["f1"])))
    best_rounds_dic["normal"].append(int(np.argmax(clients_results_normal[client_idx][1]["f1"])))
    best_rounds_dic["meta"].append(int(np.argmax(clients_results_meta[client_idx][1]["f1"])))

    fig.suptitle(f"Client {client_idx+1}")
    fig.tight_layout()

    fig.text(0.5, -0.01, "Epoch / Round", ha="center")
    # Legend anchored (in axes coordinates of the bottom-right panel) above the grid.
    ax[1, 3].legend(loc=(-0.7, 2.3), ncol=4)
    plt.show()
    plt.close(fig)  # free memory when generating many figures in a loop

In [None]:
# Test-split metrics of every method at its best-F1 epoch/round (from best_rounds_dic).
# Per client the layout is [[loss x4], [auc x4], [f1 x4], [acc x4]], where each
# inner list is ordered (personalized, random, normal, meta).
best_performance = []

for client_idx in range(10):
    test_results = {
        "personalized": local_result[1][client_idx],
        "random": clients_results_random[client_idx][1],
        "normal": clients_results_normal[client_idx][1],
        "meta": clients_results_meta[client_idx][1],
    }
    best_performance.append([
        [test_results[method][metric][best_rounds_dic[method][client_idx]]
         for method in ("personalized", "random", "normal", "meta")]
        for metric in ("loss", "auc", "f1", "acc")
    ])

In [None]:
# best_rounds_dic = {"random": {"acc":[], "f1":[], "auc":[]},
#                    "normal": {"acc":[], "f1":[], "auc":[]},
#                    "meta": {"acc":[], "f1":[], "auc":[]},
#                    "personalized": {"acc":[], "f1":[], "auc":[]},}

In [None]:
# best_rounds_dic = {"random": [],
#                    "normal": [],
#                    "meta": [],
#                    "personalized": [],}


# for client_idx in range(10):
#     best_rounds_dic["personalized"].append(np.where(local_result[1][client_idx]["f1"] == np.max(local_result[1][client_idx]["f1"]))[0][0]) 
#     best_rounds_dic["random"].append(np.where(clients_results_random[client_idx][1]["f1"] == np.max(clients_results_random[client_idx][1]["f1"]))[0][0]) 
#     best_rounds_dic["normal"].append(np.where(clients_results_normal[client_idx][1]["f1"] == np.max(clients_results_normal[client_idx][1]["f1"]))[0][0])
#     best_rounds_dic["meta"].append(np.where(clients_results_meta[client_idx][1]["f1"] == np.max(clients_results_meta[client_idx][1]["f1"]))[0][0])

In [None]:
# Test-split metrics of every method at its best-F1 epoch/round (from best_rounds_dic).
# Per client the layout is [[loss x4], [auc x4], [f1 x4], [acc x4]], where each
# inner list is ordered (personalized, random, normal, meta).
best_performance = []

for client_idx in range(10):
    test_results = {
        "personalized": local_result[1][client_idx],
        "random": clients_results_random[client_idx][1],
        "normal": clients_results_normal[client_idx][1],
        "meta": clients_results_meta[client_idx][1],
    }
    best_performance.append([
        [test_results[method][metric][best_rounds_dic[method][client_idx]]
         for method in ("personalized", "random", "normal", "meta")]
        for metric in ("loss", "auc", "f1", "acc")
    ])

In [None]:

# Grouped bar chart per metric: for each client, one bar per method showing the
# test value at that method's best-F1 epoch/round (read from best_performance).
bar_width = 0.2  # width of a single bar

# Order must match the inner order of best_performance: personalized, random, normal, meta.
method_styles = [("personalized", "blue"), ("random", "orange"), ("normal", "green"), ("meta", "red")]
n_clients = len(best_performance)
client_pos = np.arange(n_clients)
# Offset that puts a tick under the centre of a client's group of bars.
group_center = bar_width * (len(method_styles) - 1) / 2

for performance, perform_name in zip(range(1, 4), ["AUC", "F1 score", "Accuracy"]):  # 1=auc, 2=f1, 3=acc
    fig, ax = plt.subplots(figsize=(12, 6))

    # Each method's bars sit side by side, shifted by one bar width per method.
    for i, (name, color) in enumerate(method_styles):
        heights = [best_performance[c][performance][i] for c in range(n_clients)]
        ax.bar(client_pos + bar_width * i, heights, bar_width, color=color, label=name)

    # Tick positions are built from integer client positions, so their count always
    # matches the labels (a float-step np.arange could yield an extra element).
    ax.set_xticks(client_pos + group_center, np.arange(1, n_clients + 1))

    ax.set_xlabel('Client', size=13)
    ax.set_ylabel(perform_name, size=13)
    # Hard-coded y-ranges per metric.
    if perform_name in ("AUC", "Accuracy"):
        ax.set_ylim(0.4, 0.85)
    else:
        ax.set_ylim(0.3, 0.70)

    ax.legend()
    plt.show()

In [None]:
def _plot_mean_std(axis, per_client_curves, n_rounds, label, color):
    """Plot the across-client mean of the first `n_rounds` values with a ±1 std band.

    `per_client_curves` is a list holding one metric curve per client; all curves
    must have the same length (at least `n_rounds`, otherwise fewer points are drawn).
    """
    curves = np.asarray(per_client_curves, dtype=float)[:, :n_rounds]
    mean = curves.mean(axis=0)
    std = curves.std(axis=0)
    x = np.arange(len(mean))
    axis.plot(x, mean, label=label, color=color)
    axis.fill_between(x, mean - std, mean + std, color=color, alpha=0.1)


# Mean ± std across clients of each method's test curves, one panel per metric.
# All series are truncated to `want_round` so they share the same x-axis.
panels = [
    ("AUC", (0.5, 0.8), [local_test_auc, normal_test_auc, random_test_auc, meta_test_auc]),
    ("F1 Score", (0.2, 0.6), [local_test_f1, normal_test_f1, random_test_f1, meta_test_f1]),
    ("Accuracy", (0.3, 1), [local_test_acc, normal_test_acc, random_test_acc, meta_test_acc]),
]
series_styles = [("personalized", "blue"), ("normal", "green"), ("random", "orange"), ("meta", "red")]

fig, ax = plt.subplots(1, len(panels), figsize=(15, 3))

for axis, (ylabel, ylim, series) in zip(ax, panels):
    for curves, (label, color) in zip(series, series_styles):
        _plot_mean_std(axis, curves, want_round, label, color)
    axis.set_xticks(np.arange(0, want_round, 1), np.arange(1, want_round + 1, 1))
    axis.set_ylim(*ylim)
    axis.set_ylabel(ylabel)
    axis.set_xlabel("Epoch / Round")

ax[0].legend()
plt.show()