#### Data Source:
UCI Census Income Data Set: http://archive.ics.uci.edu/ml/datasets/Census+Income  

#### Pre-configuration

In [1]:
# data processing module
import numpy as np
import pandas as pd
from utils_sub import *

# federated learning module
from models import *
from FedAvg import *

# others
import os
import torch

# plot module
# import seaborn as sns

In [2]:
# set up computing resources
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### Non-iid sampling to generate subsets

In [3]:
# Load data
train_set, test_set = dataloader_adult()

In [4]:
num_clients = 5

In [5]:
train_noniid = load_noniid(num_clients, train_set)

#### Fuzzify Adult data

In [6]:
from fuzzyset import FuzzySet
from fuzzification import FuzzyData

In [7]:
def load_noniid_client(train_noniid, test_set):
    data = []
    for subset in train_noniid:
        df = pd.DataFrame(subset)
        df.columns = [*df.columns[:-1], 'target']
        # fuzzy train data subset
        train_data = FuzzyData(data = df, target = 'target')
        train_data.quantile_fuzzification()
        # get train epistemic values
        epi_train_data = train_data.get_epistemic_values().values
        # get train data labels
        train_label = df['target'].values
        # combine train data subsets
        data.append((epi_train_data, train_label))
        
    # fuzzy test data subset  
    test_data = FuzzyData(data = test_set, target = 'target')
    test_data.quantile_fuzzification()
    # get test epistemic values
    epi_test_data = test_data.get_epistemic_values().values
    # get test data labels
    test_label = test_set['target'].values
    
    data.append((epi_test_data, test_label))
    
    return data

In [8]:
data = load_noniid_client(train_noniid, test_set)

In [9]:
lr = 0.001
fl_param = {
    'output_size': 3,
    'client_num': num_clients,
    'model': MLP,
    'data': data,
    'lr': lr,
    'epoch': 3,
    'C': 1,
    'sigma': 0.5,
    'clip': 2,
    'batch_size': 128,
    'device': device
}
import warnings
warnings.filterwarnings("ignore")
fl_entity = FedAvgServer(fl_param).to(device)

In [10]:
for e in range(50):
    if e+1 % 10 == 0:
        lr *= 0.1
        fl_entity.set_lr(lr)
    acc = fl_entity.global_update()
    print("global epochs = {:d}, acc = {:.4f}".format(e+1, acc))

global epochs = 1, acc = 0.1773
global epochs = 2, acc = 0.2061
global epochs = 3, acc = 0.2380
global epochs = 4, acc = 0.2678
global epochs = 5, acc = 0.3251
global epochs = 6, acc = 0.3622
global epochs = 7, acc = 0.4306
global epochs = 8, acc = 0.5039
global epochs = 9, acc = 0.5693
global epochs = 10, acc = 0.6481
global epochs = 11, acc = 0.7112
global epochs = 12, acc = 0.7346
global epochs = 13, acc = 0.7498
global epochs = 14, acc = 0.7579
global epochs = 15, acc = 0.7602
global epochs = 16, acc = 0.7617
global epochs = 17, acc = 0.7632
global epochs = 18, acc = 0.7638
global epochs = 19, acc = 0.7638
global epochs = 20, acc = 0.7638
global epochs = 21, acc = 0.7638
global epochs = 22, acc = 0.7638
global epochs = 23, acc = 0.7638
global epochs = 24, acc = 0.7638
global epochs = 25, acc = 0.7638
global epochs = 26, acc = 0.7638
global epochs = 27, acc = 0.7638
global epochs = 28, acc = 0.7638
global epochs = 29, acc = 0.7638
global epochs = 30, acc = 0.7638
global epochs = 31,