In [1]:
import time

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import json
# from model import Model
# from femnist.cnn import ClientModel
# from utils.model_utils import read_data
import pydash
import math

pd.set_option('display.max_columns', 500)

In [2]:
from collections import defaultdict
def read_dir(data_dir):
    clients = []
    groups = []
    data = defaultdict(lambda : None)

    files = os.listdir(data_dir)
    files = [f for f in files if f.endswith('.json')]
    for f in files:
        file_path = os.path.join(data_dir,f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        clients.extend(cdata['users'])
#         for idx, u in enumerate(cdata['users']):
#             cluster_ids[u] = cdata['cluster_ids'][idx]
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
            print("hierarchies exist")
        data.update(cdata['user_data'])

    clients = list(sorted(data.keys()))
#     cluster_ids = [cluster_ids[c] for c in clients]
    return clients, groups, data

In [3]:
users, _, data,  = read_dir('/tf/work/tangle-learning/data/femnist-data/large/train')

In [4]:
_, _, data_test,  = read_dir('/tf/work/tangle-learning/data/femnist-data/large/test')

In [5]:
NUM_CLIENTS = 351

In [6]:
num_clusters = 3
digits_per_cluster = np.array_split(range(10), num_clusters)
users_per_cluster = (NUM_CLIENTS // num_clusters) + 1
print(digits_per_cluster)

[array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]


In [7]:
from copy import deepcopy
complete_data = {}
for i in range(len(users)):
    if i < NUM_CLIENTS:
        complete_data[users[i]] = deepcopy(data[users[i]])
        complete_data[users[i]]['x'].extend(data_test[users[i]]['x'])
        complete_data[users[i]]['y'].extend(data_test[users[i]]['y'])
    else:
        complete_data[users[i % NUM_CLIENTS]]['x'].extend(data[users[i]]['x'])
        complete_data[users[i % NUM_CLIENTS]]['y'].extend(data[users[i]]['y'])
        complete_data[users[i % NUM_CLIENTS]]['x'].extend(data_test[users[i]]['x'])
        complete_data[users[i % NUM_CLIENTS]]['y'].extend(data_test[users[i]]['y'])
    
users = users[:NUM_CLIENTS]

In [8]:
def clean(data, cluster_leak=0.1):
    cleaned = {}
    cluster_ids = {}
    for user_index in range(len(complete_data.keys())):
        userdata = {}
        userdata['x'] = []
        userdata['y'] = []
        
        # Add cluster data
        for i in range(num_clusters):
            user_index_2 = (user_index + i * users_per_cluster) % len(complete_data.keys())
            mask = np.array([elem in digits_per_cluster[user_index // users_per_cluster] for elem in data[users[user_index_2]]['y']])
            userdata['x'].extend(np.array(data[users[user_index_2]]['x'])[mask].tolist())
            userdata['y'].extend(np.array(data[users[user_index_2]]['y'])[mask].tolist())
        
        # Add fraction of non-cluster data
        num_cluster_data = len(userdata['y'])
        num_digits_for_user = len(digits_per_cluster[user_index // users_per_cluster])
        num_additional_data = round(num_cluster_data / num_digits_for_user * cluster_leak) * (10 - num_digits_for_user)
        for i in range(num_clusters):
            if len(userdata['y']) >= num_additional_data + num_cluster_data:
                break
            else:
                missing_additional_data = num_additional_data + num_cluster_data - len(userdata['y'])
                user_index_2 = (user_index + i * users_per_cluster) % len(complete_data.keys())
                mask = np.array([elem not in digits_per_cluster[user_index // users_per_cluster] and elem < 10 for elem in data[users[user_index_2]]['y']])
                userdata['x'].extend(np.array(data[users[user_index_2]]['x'])[mask][:missing_additional_data].tolist())
                userdata['y'].extend(np.array(data[users[user_index_2]]['y'])[mask][:missing_additional_data].tolist())
                    
        # Shuffle
        shuffle_mask = np.random.permutation(len(userdata['y']))
        userdata['x'] = np.array(userdata['x'])[shuffle_mask].tolist()
        userdata['y'] = np.array(userdata['y'])[shuffle_mask].tolist()
                    
        assert len(userdata['y']) > 0, 'Not enough data for client {}'.format(users[user_index])
        
        cleaned[users[user_index]] = userdata
        cluster_ids[users[user_index]] = user_index // users_per_cluster
    return cleaned, cluster_ids

In [9]:
def split(cleaned_old, cluster_ids, test_split=0.1):
    cleaned = {}
    cleaned_test = {}
    for username in cluster_ids.keys():
        userdata = cleaned_old[username]
        train_size = math.floor(len(userdata['y']) * (1 - test_split))
        cleaned[username]= {'x': userdata['x'][:train_size],
                            'y': userdata['y'][:train_size]}
        cleaned_test[username]= {'x': userdata['x'][train_size:],
                                 'y': userdata['y'][train_size:]}
    return cleaned, cleaned_test, cluster_ids

In [10]:
cleaned_old, cluster_ids = clean(complete_data)

In [11]:
print(len(list(cleaned_old.keys())))

print('Average # of data points')
print(np.mean([len(x['x']) for k, x in cleaned_old.items()]))

351
Average # of data points
1346.111111111111


In [12]:
print(sum([len(value['y']) for (key, value) in cleaned_old.items()]))

472485


In [15]:
train_output = {}
test_output = {} 
train_output['user_data'], test_output['user_data'], cluster_ids = split(cleaned_old, cluster_ids)
train_output['cluster_ids'] = list(cluster_ids.values())
train_output['users'] = list(cluster_ids.keys())

test_output['cluster_ids'] = train_output['cluster_ids']
test_output['users'] = train_output['users']

In [16]:
with open('../../.././data/femnist-data-clustered-alt-relaxed-small/train/data.json', 'w') as file: 
    json.dump(train_output, file)
with open('../../.././data/femnist-data-clustered-alt-relaxed-small/test/data.json', 'w') as file: 
    json.dump(test_output, file)