In [2]:
import json
import os
from collections import defaultdict
import numpy as np

In [3]:
def read_dir(data_dir):
    """Load and merge every ``*.json`` shard found directly in ``data_dir``.

    Each shard must contain ``'user_data'`` (user id -> that user's data) and
    may contain ``'users'`` together with a parallel ``'cluster_ids'`` list, and
    a ``'hierarchies'`` list.

    Args:
        data_dir: directory holding the JSON shards; other files are ignored.

    Returns:
        (clients, cluster_ids, groups, data):
            clients: sorted list of every user id present in ``'user_data'``.
            cluster_ids: list aligned with ``clients``; 0 for users whose
                shard provided no cluster id.
            groups: concatenation of all shards' ``'hierarchies'`` lists, in
                file-name order.
            data: ``defaultdict`` mapping user id -> user data; unknown ids
                map to ``None``.

    Raises:
        ValueError: if a shard's ``'users'`` and ``'cluster_ids'`` lengths differ.
    """
    cluster_ids = {}
    groups = []
    data = defaultdict(lambda: None)

    # os.listdir returns entries in arbitrary order; sort so `groups` (and which
    # shard wins for a duplicated user id) is the same on every machine/run.
    files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
    for f in files:
        file_path = os.path.join(data_dir, f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        if 'cluster_ids' in cdata:
            users, shard_cluster_ids = cdata['users'], cdata['cluster_ids']
            # The two lists are parallel; a mismatch means a corrupt shard.
            if len(users) != len(shard_cluster_ids):
                raise ValueError(
                    f"{file_path}: 'users' has {len(users)} entries but "
                    f"'cluster_ids' has {len(shard_cluster_ids)}")
            cluster_ids.update(zip(users, shard_cluster_ids))
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        data.update(cdata['user_data'])

    clients = sorted(data.keys())
    # If there are no cluster_ids in the data, assign 0 for each user
    cluster_ids = [cluster_ids.get(c, 0) for c in clients]
    return clients, cluster_ids, groups, data

def read_data(train_data_dir, test_data_dir):
    '''Parse the data in the given train and test data directories.

    Assumes:
    - the data in the input directories are .json files with keys 'users'
      and 'user_data', and optionally 'cluster_ids' and 'hierarchies'
    - the set of train set users is the same as the set of test set users

    Returns:
        clients: sorted list of client ids (identical for train and test)
        cluster_ids: list of cluster ids aligned with ``clients``; 0 for
            clients whose data carried no cluster id
        train_data: dictionary mapping client id -> train data
        test_data: dictionary mapping client id -> test data

    Raises:
        AssertionError: if the train and test directories disagree on the
            clients, the groups ('hierarchies') or the cluster ids.
    '''
    train_clients, train_cluster_ids, train_groups, train_data = read_dir(train_data_dir)
    test_clients, test_cluster_ids, test_groups, test_data = read_dir(test_data_dir)

    # Explicit raises instead of `assert` so the checks still run under `python -O`;
    # the exception type is kept as AssertionError for existing callers.
    if train_clients != test_clients:
        raise AssertionError('train and test data contain different clients')
    if train_groups != test_groups:
        raise AssertionError('train and test data contain different groups')
    if train_cluster_ids != test_cluster_ids:
        raise AssertionError('train and test data assign different cluster ids')

    # TODO: return groups if required (omitted to keep the 4-tuple return shape)
    return train_clients, train_cluster_ids, train_data, test_data

In [5]:
data_dir = '../../.././data/cifar100/regular/'
# Train and test shards live in sibling sub-directories of the regular split.
train_dir, test_dir = data_dir + 'train/', data_dir + 'test/'
train_clients, train_cluster_ids, train_data, test_data = read_data(train_dir, test_dir)

In [21]:
# Sanity check: number of loaded clients (read_dir returns one cluster id per client).
len(train_cluster_ids)

500

In [48]:
def merge_clients_by_cluster(clients, cluster_ids, train_data, test_data, keep_odd_client=True):
    """Roughly halve the number of clients by merging pairs within each cluster.

    Clusters are processed in ascending id order. Inside a cluster, clients keep
    their order from ``clients`` and are paired first-with-last,
    second-with-second-to-last, and so on. Each pair's 'x' and 'y' lists are
    concatenated (first client's data first) under the first client's id.

    Args:
        clients: list of client ids.
        cluster_ids: cluster ids aligned with ``clients``.
        train_data, test_data: mappings client id -> {'x': list, 'y': list}.
        keep_odd_client: for a cluster with an odd number of clients, fold the
            unpaired middle client into that cluster's last merged client, or
            keep it on its own if it is the cluster's only client. Pass False
            to drop it instead, which reproduces this cell's earlier output.

    Returns:
        (merged_train, merged_test, merged_cluster_ids): three dicts keyed by
        the surviving client ids, all with the same insertion order.
    """
    members = defaultdict(list)
    for client, cluster in zip(clients, cluster_ids):
        members[int(cluster)].append(client)

    def concat(records):
        # Build fresh lists so the source data is never mutated or aliased.
        return {key: [v for record in records for v in record[key]] for key in ('x', 'y')}

    merged_train, merged_test, merged_cluster_ids = {}, {}, {}
    for cluster in sorted(members):
        group = members[cluster]
        half = len(group) // 2
        bundles = [[group[i], group[-1 - i]] for i in range(half)]
        if len(group) % 2 == 1 and keep_odd_client:
            middle = group[half]
            if bundles:
                bundles[-1].append(middle)
            else:
                bundles.append([middle])
        for bundle in bundles:
            key = bundle[0]
            merged_train[key] = concat([train_data[c] for c in bundle])
            merged_test[key] = concat([test_data[c] for c in bundle])
            merged_cluster_ids[key] = cluster
    return merged_train, merged_test, merged_cluster_ids


new_train_data, new_test_data, new_cluster_ids = merge_clients_by_cluster(
    train_clients, train_cluster_ids, train_data, test_data)

# With keep_odd_client=True no training sample may be lost by the merge.
assert sum(len(d['y']) for d in new_train_data.values()) == \
    sum(len(train_data[c]['y']) for c in train_clients)

In [47]:
# Inspect the cluster id assigned to each merged client (one entry per merged client).
new_cluster_ids.values()

dict_values([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19])

In [None]:
OUTPUT_DIR = '../../.././data/cifar100/fewer_clients_more_data'

# Fix the user order once and look up each user's cluster id, so 'users',
# 'cluster_ids' and 'user_data' stay aligned without relying on several dicts
# sharing the same insertion order.
merged_users = list(new_train_data.keys())
merged_cluster_ids = [new_cluster_ids[u] for u in merged_users]

train_output = {
    'user_data': new_train_data,
    'cluster_ids': merged_cluster_ids,
    'users': merged_users
}
test_output = {
    'user_data': new_test_data,
    'cluster_ids': merged_cluster_ids,
    'users': merged_users
}

for split, output in (('train', train_output), ('test', test_output)):
    split_dir = os.path.join(OUTPUT_DIR, split)
    # open(..., 'w') does not create parent directories; make them on a fresh checkout.
    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(split_dir, 'data.json'), 'w') as file:
        json.dump(output, file)