In [1]:
#%matplotlib inline
import os
import random
import numpy as np
from collections import defaultdict as ddict
from tqdm import tqdm
import pickle
import random

In [2]:
# Hyperparameters: which dataset to partition and into how many clients.
data_name = "WN18RR"
num_client = 5

# Directory holding the raw dataset files (entities.dict, train.txt, ...).
file_path = "./data/{}/".format(data_name)

# Output pickle path, e.g. "Fed_data/WN18RR-Fed5.pkl".
folder_name = "Fed_data/"
file_name = "{}{}-Fed{}.pkl".format(folder_name, data_name, num_client)

In [3]:
def _read_id_map(dict_path):
    """Read a tab-separated "id<TAB>name" file into a {name: int id} dict.

    Blank lines (e.g. an extra newline at end of file) are skipped instead of
    raising ValueError from the two-way unpacking.
    """
    name2id = dict()
    with open(dict_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            idx, name = line.split('\t')
            name2id[name] = int(idx)
    return name2id


def load_data(file_path):
    """Load a knowledge-graph dataset directory and print basic counts.

    Expects, inside ``file_path``:
      * ``entities.dict`` / ``relations.dict``: tab-separated "id<TAB>name" lines;
      * ``train.txt`` / ``valid.txt`` / ``test.txt``: tab-separated
        "head<TAB>relation<TAB>tail" lines, parsed by ``read_triplets``.

    Args:
        file_path: dataset directory.

    Returns:
        (entity2id, relation2id, train_triplets, valid_triplets, test_triplets):
        the two {name: int id} dicts, followed by the NumPy arrays of
        (head_id, relation_id, tail_id) rows that ``read_triplets`` builds for
        each split.
    """
    print("load data from {}".format(file_path))

    entity2id = _read_id_map(os.path.join(file_path, 'entities.dict'))
    relation2id = _read_id_map(os.path.join(file_path, 'relations.dict'))

    train_triplets = read_triplets(os.path.join(file_path, 'train.txt'), entity2id, relation2id)
    valid_triplets = read_triplets(os.path.join(file_path, 'valid.txt'), entity2id, relation2id)
    test_triplets = read_triplets(os.path.join(file_path, 'test.txt'), entity2id, relation2id)

    print('num_entity: {}'.format(len(entity2id)))
    print('num_relation: {}'.format(len(relation2id)))
    print('num_train_triples: {}'.format(len(train_triplets)))
    print('num_valid_triples: {}'.format(len(valid_triplets)))
    print('num_test_triples: {}'.format(len(test_triplets)))

    return entity2id, relation2id, train_triplets, valid_triplets, test_triplets

def read_triplets(file_path, entity2id, relation2id):
    """Read a tab-separated "head<TAB>relation<TAB>tail" file into id triples.

    Args:
        file_path: path of the triple file.
        entity2id: mapping from entity name to integer id.
        relation2id: mapping from relation name to integer id.

    Returns:
        Integer array of shape (num_triples, 3) whose rows are
        (head_id, relation_id, tail_id). An empty file yields shape (0, 3)
        instead of a 1-D float array, so column slicing like ``[:, 0]`` still
        works downstream.

    Raises:
        KeyError: if a name is missing from ``entity2id`` / ``relation2id``.
    """
    triplets = []

    with open(file_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            head, relation, tail = line.split('\t')
            triplets.append((entity2id[head], relation2id[relation], entity2id[tail]))

    # dtype=int is NumPy's default integer type for Python ints, so non-empty
    # results are unchanged; reshape keeps the 2-D layout for empty input.
    return np.array(triplets, dtype=int).reshape(-1, 3)

In [4]:
# Load id mappings plus the three original splits (load_data prints counts).
(entity2id, relation2id,
 train_triplets, valid_triplets, test_triplets) = load_data(file_path)

load data from ./data/WN18RR/
num_entity: 40943
num_relation: 11
num_train_triples: 86835
num_valid_triples: 3034
num_test_triples: 3134


In [5]:
# Seed BOTH generators: Python's `random` drives the per-client shuffles later,
# while `np.random.shuffle` below uses NumPy's separate global RNG. Seeding only
# `random` left the client partition different on every run.
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)

# Pool all train/valid/test triples; they are re-split per client afterwards.
triples = np.concatenate((train_triplets, valid_triplets, test_triplets), axis=0)

# Shuffle the rows in place so the client partition below is random.
np.random.shuffle(triples)

In [6]:
# Partition the shuffled pool into `num_client` contiguous, near-equal chunks
# (np.array_split tolerates uneven sizes) and turn each chunk into a plain list
# of [h, r, t] Python-int lists for the per-client processing below.
# A comprehension avoids leaking loop variables (the old `idx`) into the
# notebook namespace, where the next cell could silently pick them up.
client_triples = [chunk.tolist() for chunk in np.array_split(triples, num_client)]

In [7]:
# Split each client's triples into local train / valid / test sets.


def _reindex_triples(all_triples):
    """Assign client-local ids and count occurrences for one client's triples.

    Args:
        all_triples: iterable of [h, r, t] global-id triples.

    Returns:
        (triples_reidx, ent_freq, rel_freq). Each row of ``triples_reidx`` is
        ``[h, r, t, local_h, local_r, local_t]``. Local entity and relation ids
        are consecutive integers in order of first appearance. The two
        ``defaultdict(int)`` counters record how often each global entity
        (as head or tail) and relation occurs.
    """
    ent_reidx, rel_reidx = {}, {}
    ent_freq, rel_freq = ddict(int), ddict(int)
    triples_reidx = []
    for h, r, t in all_triples:
        ent_freq[h] += 1
        ent_freq[t] += 1
        rel_freq[r] += 1
        for ent in (h, t):
            if ent not in ent_reidx:
                ent_reidx[ent] = len(ent_reidx)  # next consecutive local id
        if r not in rel_reidx:
            rel_reidx[r] = len(rel_reidx)
        triples_reidx.append([h, r, t, ent_reidx[h], rel_reidx[r], ent_reidx[t]])
    return triples_reidx, ent_freq, rel_freq


def _holdout_split(triples_reidx, ent_freq, rel_freq, holdout_frac=0.2):
    """Greedily move triples into a held-out set without orphaning ids.

    A triple is held out only while its head, tail and relation each still
    occur more than twice among the not-yet-held-out triples. Every held-out
    entity or relation therefore still appears in at least one training triple.

    Scanning stops as soon as the held-out set grows beyond
    ``int(len(triples_reidx) * holdout_frac)`` triples, as in the original
    split. The unscanned remainder goes to train. ``ent_freq`` and
    ``rel_freq`` are decremented in place.

    Returns:
        (train_triples, held_out_triples)
    """
    train_triples, held_out = [], []
    max_held_out = int(len(triples_reidx) * holdout_frac)
    stop = len(triples_reidx)  # no early break -> every triple was scanned
    for idx, tri in enumerate(triples_reidx):
        h, r, t = tri[:3]
        if ent_freq[h] > 2 and ent_freq[t] > 2 and rel_freq[r] > 2:
            held_out.append(tri)
            ent_freq[h] -= 1
            ent_freq[t] -= 1
            rel_freq[r] -= 1
        else:
            train_triples.append(tri)
        if len(held_out) > max_held_out:
            stop = idx + 1
            break
    train_triples.extend(triples_reidx[stop:])
    return train_triples, held_out


def _to_edge_arrays(triples):
    """Convert ``[h, r, t, local_h, local_r, local_t]`` rows into edge arrays.

    Returns a dict with:
      * ``edge_index``: 2 x N local head/tail ids;
      * ``edge_type``: N local relation ids;
      * ``edge_index_ori`` / ``edge_type_ori``: the same, using global ids.

    Empty input yields (2, 0) / (0,) arrays instead of raising IndexError.
    """
    arr = np.array(triples, dtype=int).reshape(-1, 6)
    return {'edge_index': arr[:, [3, 5]].T, 'edge_type': arr[:, 4],
            'edge_index_ori': arr[:, [0, 2]].T, 'edge_type_ori': arr[:, 1]}


client_data = []

for client_idx in tqdm(range(num_client)):
    triples_reidx, ent_freq, rel_freq = _reindex_triples(client_triples[client_idx])

    # Keep both random.shuffle calls in this order: it makes the split identical
    # to the previous implementation for the same `random` seed.
    random.shuffle(triples_reidx)
    client_train_triples, held_out = _holdout_split(triples_reidx, ent_freq, rel_freq)

    # The first half of the shuffled held-out triples is validation, the rest test.
    random.shuffle(held_out)
    half = len(held_out) // 2
    client_valid_triples, client_test_triples = held_out[:half], held_out[half:]

    client_data.append({'train': _to_edge_arrays(client_train_triples),
                        'test': _to_edge_arrays(client_test_triples),
                        'valid': _to_edge_arrays(client_valid_triples)})

100%|█████████████████████████████████████████████| 5/5 [00:00<00:00,  7.92it/s]


In [8]:
# Save the per-client dataset. Create the output folder if it does not exist,
# and close the file handle deterministically so the pickle is fully flushed.
os.makedirs(os.path.dirname(file_name) or ".", exist_ok=True)
with open(file_name, 'wb') as out_file:
    pickle.dump(client_data, out_file)

#### Check the statistics of the dataset

In [9]:
def _unique_entities(triplets):
    """Return the sorted unique entity ids used as head or tail in ``triplets``."""
    # Columns 0 and 2 hold head and tail ids; np.unique flattens the 2-column slice.
    return np.unique(triplets[:, [0, 2]])


# Distinct entities in each original split: a = train, b = valid, c = test.
a = _unique_entities(train_triplets)
b = _unique_entities(valid_triplets)
c = _unique_entities(test_triplets)

In [10]:
# Number of distinct entities in train, valid and test (in that order).
print(*map(len, (a, b, c)))

40559 5173 5323


In [11]:
# Distinct relation ids (column 1) per split: e = train, f = valid, g = test.
e, f, g = (np.unique(split[:, 1])
           for split in (train_triplets, valid_triplets, test_triplets))

print(len(e), len(f), len(g))

11 11 11


In [12]:
# Entities seen in valid or test but never in train: (b ∪ c) \ a.
h = np.setdiff1d(np.union1d(b, c), a)

# Train entities plus the valid/test-only ones should cover every entity id.
len(a) + len(h)

40943