In [2]:
import time

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import json
# from model import Model
# from femnist.cnn import ClientModel
# from utils.model_utils import read_data
import pydash
import math

pd.set_option('display.max_columns', 500)

In [3]:
from collections import defaultdict
def read_dir(data_dir):
    clients = []
    groups = []
    data = defaultdict(lambda : None)

    files = os.listdir(data_dir)
    files = [f for f in files if f.endswith('.json')]
    for f in files:
        file_path = os.path.join(data_dir,f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        clients.extend(cdata['users'])
#         for idx, u in enumerate(cdata['users']):
#             cluster_ids[u] = cdata['cluster_ids'][idx]
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
            print("hierarchies exist")
        data.update(cdata['user_data'])

    clients = list(sorted(data.keys()))
#     cluster_ids = [cluster_ids[c] for c in clients]
    return clients, groups, data

In [4]:
users, _, data,  = read_dir('/tf/work/tangle-learning/data/femnist-data/large/train')

In [5]:
_, _, data_test,  = read_dir('/tf/work/tangle-learning/data/femnist-data/large/test')

In [6]:
num_clusters = 3
digits_per_cluster = np.array_split(range(10), num_clusters)
users_per_cluster = (len(users) // num_clusters)
print(digits_per_cluster)

[array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]


In [7]:
from copy import deepcopy
complete_data = deepcopy(data)
for i in range(users_per_cluster * num_clusters, len(users)):
    del complete_data[users[i]]
for i in range(users_per_cluster * num_clusters):
    complete_data[users[i]]['x'].extend(data_test[users[i]]['x'])
    complete_data[users[i]]['y'].extend(data_test[users[i]]['y'])
total_number_samples = np.sum([np.sum(np.array(value['y']) < 10) for (key, value) in complete_data.items()])
print(total_number_samples)

393581


In [34]:
def clean(data, cluster_leak=0.1):
    cleaned = {}
    cluster_ids = {}
    for user_index in range(len(complete_data.keys())):
        userdata = {}
        userdata['x'] = []
        userdata['y'] = []
        
        # Add cluster data
        for i in range(num_clusters):
            user_index_2 = (user_index + i * users_per_cluster) % len(complete_data.keys())
            mask = np.array([elem in digits_per_cluster[user_index // users_per_cluster] for elem in data[users[user_index_2]]['y']])
            userdata['x'].extend(np.array(data[users[user_index_2]]['x'])[mask].tolist())
            userdata['y'].extend(np.array(data[users[user_index_2]]['y'])[mask].tolist())
        
        # Add fraction of non-cluster data
        num_cluster_data = len(userdata['y'])
        num_digits_for_user = len(digits_per_cluster[user_index // users_per_cluster])
        num_additional_data = math.floor(num_cluster_data / num_digits_for_user * cluster_leak) * (10 - num_digits_for_user)
        for i in range(num_clusters):
            if len(userdata['y']) >= num_additional_data + num_cluster_data:
                break
            else:
                missing_additional_data = num_additional_data + num_cluster_data - len(userdata['y'])
                user_index_2 = (user_index + i * users_per_cluster) % len(complete_data.keys())
                mask = np.array([elem not in digits_per_cluster[user_index // users_per_cluster] for elem in data[users[user_index_2]]['y']])
                userdata['x'].extend(np.array(data[users[user_index_2]]['x'])[mask][:missing_additional_data].tolist())
                userdata['y'].extend(np.array(data[users[user_index_2]]['y'])[mask][:missing_additional_data].tolist())
                    
        # Shuffle
        shuffle_mask = np.random.permutation(len(userdata['y']))
        userdata['x'] = np.array(userdata['x'])[shuffle_mask].tolist()
        userdata['y'] = np.array(userdata['y'])[shuffle_mask].tolist()
                    
        assert len(userdata['y']) > 0, 'Not enough data for client {}'.format(users[user_index])
        
        cleaned[users[user_index]] = userdata
        cluster_ids[users[user_index]] = user_index // users_per_cluster
    return cleaned, cluster_ids

In [9]:
def split(cleaned_old, cluster_ids, test_split=0.1):
    cleaned = {}
    cleaned_test = {}
    for username in cluster_ids.keys():
        userdata = cleaned_old[username]
        train_size = math.floor(len(userdata['y']) * (1 - test_split))
        cleaned[username]= {'x': userdata['x'][:train_size],
                            'y': userdata['y'][:train_size]}
        cleaned_test[username]= {'x': userdata['x'][train_size:],
                                 'y': userdata['y'][train_size:]}
    return cleaned, cleaned_test, cluster_ids

In [33]:
cleaned_old, cluster_ids = clean(complete_data)



136
4
20
20


149
4
22
22


140
4
21
21


137
4
20
20


123
4
18
18


128
4
19
19


132
4
19
19


135
4
20
20


137
4
20
20


144
4
21
21


150
4
22
22


148
4
22
22


128
4
19
19


141
4
21
21


147
4
22
22


140
4
21
21


129
4
19
19


136
4
20
20


126
4
18
18


148
4
22
22


139
4
20
20


133
4
19
19


127
4
19
19


152
4
22
22


142
4
21
21


109
4
16
16


136
4
20
20


144
4
21
21


134
4
20
20


128
4
19
19


123
4
18
18


143
4
21
21


130
4
19
19


117
4
17
17


138
4
20
20


152
4
22
22


131
4
19
19


153
4
22
22


132
4
19
19


129
4
19
19


141
4
21
21


129
4
19
19


145
4
21
21


134
4
20
20


131
4
19
19


146
4
21
21


151
4
22
22


129
4
19
19


123
4
18
18


142
4
21
21


149
4
22
22


152
4
22
22


105
4
15
15


144
4
21
21


144
4
21
21


136
4
20
20


112
4
16
16


145
4
21
21


144
4
21
21


130
4
19
19


138
4
20
20


149
4
22
22


151
4
22
22


153
4
22
22


151
4
22
22


136
4
20
20


149
4
22
22


137
4
20
20


126
4
18
18


139
4
20
20


152
4
22
22


138




136
4
20
20


156
4
23
23


114
4
17
17


169
4
25
25


133
4
19
19


144
4
21
21


134
4
20
20


147
4
22
22


145
4
21
21


149
4
22
22


140
4
21
21


141
4
21
21


147
4
22
22


138
4
20
20


136
4
20
20


125
4
18
18


134
4
20
20


132
4
19
19


146
4
21
21


136
4
20
20


123
4
18
18


151
4
22
22


139
4
20
20


152
4
22
22


138
4
20
20


138
4
20
20


144
4
21
21


137
4
20
20


133
4
19
19


138
4
20
20


133
4
19
19


141
4
21
21


137
4
20
20


150
4
22
22


140
4
21
21


142
4
21
21


138
4
20
20


144
4
21
21


132
4
19
19


144
4
21
21


151
4
22
22


151
4
22
22


137
4
20
20


140
4
21
21


149
4
22
22


146
4
21
21


137
4
20
20


123
4
18
18


150
4
22
22


150
4
22
22


133
4
19
19


141
4
21
21


137
4
20
20


147
4
22
22


144
4
21
21


138
4
20
20


140
4
21
21


142
4
21
21


150
4
22
22


138
4
20
20


145
4
21
21


147
4
22
22


121
4
18
18


138
4
20
20


135
4
20
20


145
4
21
21


141
4
21
21


144
4
21
21


146
4
21
21


140
4
21
21


146
4
21
21


113




89
3
20
20


112
3
26
26


89
3
20
20


99
3
23
23


66
3
15
15


101
3
23
23


102
3
23
23


105
3
24
24


96
3
22
22


90
3
21
21


96
3
22
22


103
3
24
24


94
3
21
21


101
3
23
23


97
3
22
22


104
3
24
24


84
3
19
19


68
3
15
15


100
3
23
23


105
3
24
24


98
3
22
22


88
3
20
20


89
3
20
20


100
3
23
23


94
3
21
21


79
3
18
18


94
3
21
21


99
3
23
23


80
3
18
18


107
3
24
24


98
3
22
22


89
3
20
20


94
3
21
21


83
3
19
19


97
3
22
22


91
3
21
21


85
3
19
19


106
3
24
24


101
3
23
23


89
3
20
20


93
3
21
21


100
3
23
23


105
3
24
24


114
3
26
26


68
3
15
15


91
3
21
21


109
3
25
25


95
3
22
22


82
3
19
19


95
3
22
22


96
3
22
22


82
3
19
19


92
3
21
21


101
3
23
23


105
3
24
24


105
3
24
24


95
3
22
22


87
3
20
20


110
3
25
25


87
3
20
20


87
3
20
20


100
3
23
23


106
3
24
24


77
3
17
17


99
3
23
23


74
3
17
17


76
3
17
17


92
3
21
21


85
3
19
19


108
3
25
25


93
3
21
21


71
3
16
16


88
3
20
20


78
3
18
18


87
3
20
20



KeyboardInterrupt: 

In [22]:
print(len(list(cleaned_old.keys())))
print(len(cluster_ids))

3498
3498


In [23]:
print(sum([len(value['y']) for (key, value) in cleaned_old.items()]))

470175


In [25]:
train_output = {}
test_output = {} 
train_output['user_data'], test_output['user_data'], train_output['cluster_ids'] = split(cleaned_old, cluster_ids)
train_output['users'] = list(train_output['user_data'].keys())

test_output['cluster_ids'] = train_output['cluster_ids']
test_output['users'] = train_output['users']

In [27]:
cluster_clients = {}
for i in range(3):
    cluster_clients[i] = []
for index, client in enumerate(train_output['user_data'].keys()):
    cluster_clients[cluster_ids[client]].append(client)

cluster_data = {cluster: np.concatenate([train_output['user_data'][client]['y'] for client in clients]) for (cluster, clients) in cluster_clients.items()}

for cluster in range(3):
    hist, _ = np.histogram(cluster_data[cluster], bins=range(11))
    print(hist)

[35302 39171 35032 36100  1012   880  1018  1100   984  1207]
[ 1205  1319  1192  1189 34291 32067 34965  1276  1171  1487]
[ 1199  1338  1273  1222  1227  1249  1200 36666 34677 34914]


In [24]:
with open('../../.././data/femnist-data-clustered-alt-relaxed/train/data.json', 'w') as file: 
    json.dump(train_output, file)
with open('../../.././data/femnist-data-clustered-alt-relaxed/test/data.json', 'w') as file: 
    json.dump(test_output, file)

432939.1

In [38]:
num_cluster_data = (35302 +39171 +35032 +36100)
print(num_cluster_data)
num_digits = 4
math.floor(num_cluster_data / num_digits * 0.1)

145605


3640

In [36]:
1012  + 880 + 1018 + 1100 +  984 + 1207

6201

8264.333333333334