In [None]:
import copy
import math
import os
import sys
import time
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import LeaveOneGroupOut

from config import cols, labels
from train_functions.MLPLogisticFL import MLPLogisticFL
from train_functions.fncs import select_feature

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Seed every RNG the experiment draws from so Restart & Run All is reproducible:
# fl_run samples clients with np.random.choice and initialises fresh models via torch.
# (Previously `seed` was defined but never applied.)
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Dataset name: selects ./datas/<name>.csv and is the key into `cols` / `labels`.
# NOTE(review): this shadows the builtin `type`; kept because fl_run reads it by name.
type = 'Wesad'
df = pd.read_csv(f'./datas/{type}.csv',index_col=0)
print(df.shape)

# Leave-one-subject-out CV: each fold holds out all rows of one `pnum` value.
logo = LeaveOneGroupOut()
groups = df['pnum']
unique_groups = np.unique(groups)
group_count = len(unique_groups)
print(group_count)

personal = 100             # leading samples dropped from the held-out subject's test split
hidden = 32                # passed to MLPLogisticFL; presumably the hidden width — TODO confirm
epochs = 1                 # passed to MLPLogisticFL; presumably local epochs per fit — TODO confirm
lr = 1e-3                  # passed to MLPLogisticFL as the learning rate argument
rounds = 100               # federated communication rounds per fold
L = 0.1                    # fraction of each client's rows given to fit() together with labels
participation_ratio = 0.8  # fraction of training clients sampled (without replacement) per round
import math

def compute_tau(C, Cs, R, tau_min=0.5, tau_max=0.9, delta=0.5):
    """Cosine-ramped schedule value in [tau_min, tau_max].

    In fl_run, `C` is the current round index, `Cs` is how many times the
    client has participated so far, and `R` is the total number of rounds.
    `delta` blends the two: the effective step is C - delta * (C - Cs),
    i.e. (1 - delta) * C + delta * Cs. That step, divided by R and clamped
    to [0, 1], drives a half-cosine ramp. The ramp returns tau_min at
    progress 0 and tau_max at progress 1.

    The result is assigned to the client model's `.t` attribute.
    NOTE(review): presumably a confidence threshold inside MLPLogisticFL.fit — TODO confirm.
    """
    effective_step = C - delta * (C - Cs)
    frac = effective_step / R
    # Clamp to [0, 1] so out-of-range steps saturate at the endpoints.
    if frac < 0:
        frac = 0
    elif frac > 1:
        frac = 1
    half_span = 0.5 * (tau_max - tau_min)
    return tau_max - half_span * (1 + math.cos(math.pi * frac))

def _fedavg_aggregate(client_params, selected_clients, client_data_sizes):
    """Data-size-weighted average of the selected clients' parameters (FedAvg).

    Weights are normalised over `selected_clients` only, so they sum to 1.
    Normalising by the data of *all* clients would scale the averaged
    parameters down by the participation fraction every round.
    """
    selected_total = float(sum(client_data_sizes[cid] for cid in selected_clients))
    weights = {cid: float(client_data_sizes[cid]) / selected_total for cid in selected_clients}
    return {
        name: sum(client_params[cid][name] * weights[cid] for cid in selected_clients)
        for name in client_params[selected_clients[0]]
    }


def _evaluate_accuracy(model, x_test, y_test):
    """Accuracy of `model` on one test split; no autograd graph is built."""
    with torch.no_grad():
        logits = model.forward(torch.Tensor(x_test))
        probs = torch.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).numpy()
    return accuracy_score(y_test, pred)


def fl_run() :
    """Run one leave-one-subject-out federated-learning experiment.

    For every LOGO fold over `groups`:
      * each training subject (`pnum`) is one client with its own MLPLogisticFL;
      * each of `rounds` rounds samples `participation_ratio` of the clients
        without replacement (global NumPy RNG);
      * each sampled client sets `.t` via compute_tau, starts from the current
        global weights, and calls fit() with three arguments:
          - the first `L` fraction of its rows as the first argument;
          - the remaining rows as the second argument (presumably unlabeled
            data — TODO confirm);
          - the labels of the first `L` fraction;
      * the global model is the data-size-weighted average of the sampled clients;
      * the global model is evaluated on the held-out subject after dropping its
        first `personal` samples.

    Writes a progress bar to stdout after each fold.

    Returns:
        list[list[float]]: accs[fold][round] = test accuracy of the global model.
    """
    accs = []
    # Features/labels are the same for every fold; build them once.
    x_data = df.drop(columns=cols[type])
    y_data = df[labels[type]]
    for i, (train_idx, test_idx) in enumerate(logo.split(df, y_data, groups)):
        x_train, x_test = np.array(x_data.iloc[train_idx]), np.array(x_data.iloc[test_idx])
        y_train = np.array(y_data.iloc[train_idx]).reshape(-1, 1)
        y_test = np.array(y_data.iloc[test_idx]).reshape(-1, 1)
        pnum_train = groups.iloc[train_idx].values
        # Remove part of the held-out subject's data (its first `personal` samples).
        x_test, y_test = x_test[personal:], y_test[personal:]

        # One FedAvg client per training subject.
        client_ids = np.unique(pnum_train)
        client_models = {cid: MLPLogisticFL(x_train.shape[1], hidden, epochs, lr) for cid in client_ids}
        client_participation = {cid: 0 for cid in client_ids}
        client_data_sizes = {cid: (pnum_train == cid).sum() for cid in client_ids}
        client_params = {}
        # At least one client must take part, or aggregation has nothing to average.
        n_selected = min(len(client_ids), max(1, int(len(client_ids) * participation_ratio)))

        global_model = MLPLogisticFL(x_train.shape[1], hidden, epochs, lr)
        # Broadcast weights kept in memory (was a round-trip through 'save.pt').
        global_state = copy.deepcopy(global_model.state_dict())
        fold_accs = []
        accs.append(fold_accs)
        for round_idx in range(rounds):
            selected_clients = np.random.choice(client_ids, size=n_selected, replace=False)

            # Local training, each client starting from the current global weights.
            for cid in selected_clients:
                model = client_models[cid]
                model.t = compute_tau(round_idx, client_participation[cid], rounds)
                client_participation[cid] += 1
                idx = np.where(pnum_train == cid)[0]
                x_local, y_local = x_train[idx], y_train[idx]
                n_labeled = int(len(x_local) * L)
                x_rest = x_local[n_labeled:]
                x_labeled, y_labeled = x_local[:n_labeled], y_local[:n_labeled]
                model.load_state_dict(global_state)
                model.fit(x_labeled, x_rest, y_labeled)
                client_params[cid] = {name: v.detach().clone() for name, v in model.state_dict().items()}

            # FedAvg: weighted average of the participating clients' parameters.
            global_model.load_state_dict(_fedavg_aggregate(client_params, selected_clients, client_data_sizes))
            global_state = copy.deepcopy(global_model.state_dict())

            fold_accs.append(_evaluate_accuracy(global_model, x_test, y_test))

        filled = int((i + 1) * 31 / group_count)
        bar = '[' + '=' * filled + ' ' * (31 - filled) + ']'
        percent = ((i + 1) / group_count) * 100
        sys.stdout.write(f'\rProgress: {bar} {percent:.2f}% ({(i + 1)}/{group_count})')
        sys.stdout.flush()

    return accs

In [None]:
# Repeat the full federated experiment 5 times and keep every run's accs[fold][round].
# The global RNG state advances between calls, so each run draws different client subsets.
rets = [fl_run() for _ in range(5)]

In [None]:
# One accuracy-vs-round curve per held-out subject, averaged over the repeated runs.
# `plt` must be imported at the top of the notebook; it was previously only
# imported in a later cell, so this cell failed on a fresh kernel.
fig, ax = plt.subplots()
for fold_curve in np.mean(rets, axis=0):
    ax.plot(fold_curve)
ax.set(title='Test accuracy per held-out subject (mean over runs)', xlabel='Round', ylabel='Accuracy')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Accuracy vs. round, averaged over runs and then over held-out subjects.
mean_curve = np.mean(np.mean(rets, axis=0), axis=0)
fig, ax = plt.subplots()
ax.plot(mean_curve)
ax.set(title='Mean test accuracy over subjects and runs', xlabel='Round', ylabel='Accuracy')
plt.show()

In [None]:
# Per-run score: average over held-out subjects (axis 1), then over all rounds.
# NOTE(review): every round is included, not only the final one — confirm this is the intended summary.
fold_mean = np.mean(rets, axis=1)
acc_times = np.mean(fold_mean, axis=-1)
print(f'{np.mean(acc_times):.4f} {np.std(acc_times):.4f}')

In [None]:
# Mean accuracy for each round: average over held-out subjects (axis -2), then over runs.
per_run_curve = np.mean(rets, axis=-2)
print(np.mean(per_run_curve, axis=0))