In [1]:
import numpy as np
import tensorflow as tf
import importlib

import random
import os
import sys

from client import Client
from utils.args import parse_args
from utils.model_utils import read_data
from utils.baseline_constants import ACCURACY_KEY

In [20]:
def get_card(a):
    return np.bincount(a).argmax()

# using hard-voting for now
def get_vote(x):
    a = x.transpose()
    b = np.apply_along_axis(get_card, 1, a)
    return b.reshape(-1)
    
def setup_clients(dataset, model=None, use_val_set=False):
    eval_set = 'test' if not use_val_set else 'val'
    train_data_dir = os.path.join('data', dataset, 'data', 'train')
    test_data_dir = os.path.join('data', dataset, 'data', eval_set)

    users, groups, train_data, test_data = read_data(train_data_dir, test_data_dir)
    if len(groups) == 0:
        groups = [[] for _ in users]    
    clients = [Client(u, g, train_data[u], test_data[u], model) for u, g in zip(users, groups)]

    return clients, train_data, test_data



In [21]:
model_path = '%s.%s' % ("femnist", "cnn")
print('############################## %s ##############################' % model_path)
mod = importlib.import_module(model_path)
ClientModel = getattr(mod, 'ClientModel')



############################## femnist.cnn ##############################


In [27]:
path = "../fedmc/models/checkpoints/femnist"

K = 3
lr = 0.01
num_class = 62

model_initlization = ClientModel(123456, lr, num_class)
ens_models = []
for i in range(K):
    model = ClientModel(random.randint(1, 10000), lr, num_class)
    model.load_ckp(os.path.join(path, "K-{}".format(K), "K{}-C{}.ckpt".format(K, i+1)))
    ens_models.append(model)


INFO:tensorflow:Restoring parameters from ../fedmc/models/checkpoints/femnist/K-5/K5-C1.ckpt
INFO:tensorflow:Restoring parameters from ../fedmc/models/checkpoints/femnist/K-5/K5-C2.ckpt
INFO:tensorflow:Restoring parameters from ../fedmc/models/checkpoints/femnist/K-5/K5-C3.ckpt
INFO:tensorflow:Restoring parameters from ../fedmc/models/checkpoints/femnist/K-5/K5-C4.ckpt
INFO:tensorflow:Restoring parameters from ../fedmc/models/checkpoints/femnist/K-5/K5-C5.ckpt


In [9]:
clients, train_data, test_data = setup_clients("femnist", model_initlization)
# pred0 = ensemble0.test(clients[0].eval_data)["output"]
# pred1 = ensemble1.test(clients[0].eval_data)["output"]
# pred2 = ensemble2.test(clients[0].eval_data)["output"]


In [23]:
def strong_learn_pred(idx):
    outcome = ens_models[0].test(clients[idx].eval_data)["output"]
    for i in range(1, K):
        pred = ens_models[i].test(clients[idx].eval_data)["output"]
        outcome = np.vstack([outcome, pred])
    data = clients[idx].eval_data
    labels = np.array(data["y"])
    correct = np.count_nonzero(np.equal(get_vote(outcome), labels))
    return correct
    
def fine_tune_pred(idx, m_id):
    model_initlization.set_params(ens_models[m_id].get_params())
    clients[idx].train(5)
    data = clients[idx].eval_data
    labels = np.array(data["y"])    
    pred =  model_initlization.test(clients[idx].eval_data)["output"]
    correct = np.count_nonzero(np.equal(pred, labels))
    return correct   

In [11]:
correct = np.count_nonzero(np.equal(get_vote(a), labels))
correct

NameError: name 'a' is not defined

# evalation on K=3

In [24]:
sum_ens = 0
sum_fin = 0
for n,c in enumerate(clients[:100]):
    data = clients[n].eval_data
    labels = np.array(data["y"])
    #print(c.id, ",total:", len(labels))
    sum_ens += strong_learn_pred(n)
    tmp_list_fine = [fine_tune_pred(n, i) for i in range(K)]
    sum_fin += np.max(tmp_list_fine)
    
print(sum_ens)
print(sum_fin)

5502
5603


# evalation on K=4

In [26]:
sum_ens = 0
sum_fin = 0
for n,c in enumerate(clients[:100]):
    sum_ens += strong_learn_pred(n)
    tmp_list_fine = [fine_tune_pred(n, i) for i in range(K)]
    sum_fin += np.max(tmp_list_fine)
    
print(sum_ens)
print(sum_fin)

5128
5319


# evalation on K=5

In [28]:
sum_ens = 0
sum_fin = 0
for n,c in enumerate(clients[:100]):
    sum_ens += strong_learn_pred(n)
    tmp_list_fine = [fine_tune_pred(n, i) for i in range(K)]
    sum_fin += np.max(tmp_list_fine)
    
print(sum_ens)
print(sum_fin)

5322
5599


# evalation on K=6

In [66]:
sum_ens = 0
sum_fin = 0
for n,c in enumerate(clients[:100]):
    sum_ens += strong_learn_pred(n)
    sum_fin += fine_tune_pred(n)
    
print(sum_ens)
print(sum_fin)

5181
4895


# evaluation on K=7

In [68]:
sum_ens = 0
sum_fin = 0
for n,c in enumerate(clients[:100]):
    sum_ens += strong_learn_pred(n)
    sum_fin += fine_tune_pred(n)
    
print(sum_ens)
print(sum_fin)

5226
5359


# evaluation on K=8

In [None]:
sum_ens = 0
sum_fin = 0
for n,c in enumerate(clients[:100]):
    sum_ens += strong_learn_pred(n)
    sum_fin += fine_tune_pred(n)
    
print(sum_ens)
print(sum_fin)