In [1]:
import pandas as pd
import os
import shutil
from torchkge.utils.operations import get_dictionaries

import utils

import pickle

In [2]:
def split_df(df, train_ratio=0.95):
    train_len = int(len(df) * train_ratio)
    
    df = df.sample(frac=1).reset_index(drop=True)
    
    train_df = df.loc[:train_len]
    test_df = df.loc[train_len:]
    
    return train_df, test_df    

In [3]:
def split(path):
    print('path to dataset: ', path)
    
    intra_df1 = pd.read_csv(path + 'train1.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    ent2ix_1 = get_dictionaries(intra_df1, ent=True)
    rel2ix_1 = get_dictionaries(intra_df1, ent=False)
    n1 = len(ent2ix_1)
    ent_set_1 = set(ent2ix_1.keys())
    
    
    
    intra_df2 = pd.read_csv(path + 'train2.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    ent2ix_2 = get_dictionaries(intra_df2, ent=True)
    rel2ix_2 = get_dictionaries(intra_df2, ent=False)
    n2 = len(ent2ix_2)
    ent_set_2 = set(ent2ix_2.keys())
    
    common_df = pd.read_csv(path + 'train_common.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    ent2ix_common = get_dictionaries(common_df, ent=True)
    rel2ix_common = get_dictionaries(common_df, ent=False)
    n_common = len(ent2ix_common)
    ent_set_common = set(ent2ix_common.keys())
    
    
    cross_1_common = pd.read_csv(path + 'cross_1_common.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    cross_2_common = pd.read_csv(path + 'cross_2_common.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    
    # merge ent2ix from all df
    ent2ix_1_common = {k: v + n1 for k, v in ent2ix_common.items()}
    ent2ix_1.update(ent2ix_1_common)
    ent2ix_2 = {k: v + n1 + n_common for k, v in ent2ix_2.items()}
    ent2ix_1.update(ent2ix_2)
    ent2ix = ent2ix_1

    print('total entities: ', n1 + n_common + n2, len(ent2ix))

    # merge rel2ix from all df
    dif_rel = set(rel2ix_common.keys()).difference(set(rel2ix_1.keys()))
    for rel in dif_rel:
        rel_id = len(rel2ix_1)
        rel2ix_1[rel] = rel_id
    dif_rel = set(rel2ix_2.keys()).difference(set(rel2ix_1.keys()))
    for rel in dif_rel:
        rel_id = len(rel2ix_1)
        rel2ix_1[rel] = rel_id

    rel2ix = rel2ix_1
    
    intra_df = pd.concat([intra_df1, cross_1_common, common_df, cross_2_common, intra_df2], axis=0)
    
    kg = utils.Extended_KnowledgeGraph(df=intra_df, ent2ix=ent2ix, rel2ix=rel2ix)
    
    kg, intra_kg_test = kg.split_kg(share=0.95)    
    
    # inter-domain
    inter_df = pd.read_csv(path + 'cross_12.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    
    inter_kg = utils.Extended_KnowledgeGraph(df=inter_df, ent2ix=kg.ent2ix, rel2ix=kg.rel2ix)
    
    print('\t\t len INTER df test: ', len(inter_df))
    
    kg_dict = {'intra_train': kg,
               'intra_test': intra_kg_test,
               'inter_test': inter_kg,
               'n1': n1,
               'n2': n2,
               'n_common': n_common,
              }
    
    # saving
    save_path = path + 'fix_intra-test/'
    if os.path.exists(save_path) and os.path.exists(save_path):
        shutil.rmtree(save_path)
    if not os.path.exists(save_path):
        os.makedirs(save_path)    
        
    pickle.dump(kg_dict, open(save_path + 'kg_dict.pkl', 'bw'))
    
    print('Done!!!')

# split FB15k-237

In [5]:
FB15k_15_path = '../data/FB15k-237/semi/divided/'
split(FB15k_15_path)

path to dataset:  ../data/FB15k-237/semi/divided/
total entities:  5293 5293
		 len INTER df test:  22879
Done!!!


In [6]:
FB15k_3_path = '../data/FB15k-237/semi_3/divided/'
split(FB15k_3_path)

path to dataset:  ../data/FB15k-237/semi_3/divided/
total entities:  5435 5435
		 len INTER df test:  22006
Done!!!


In [7]:
FB15k_5_path = '../data/FB15k-237/semi_5/divided/'
split(FB15k_5_path)

path to dataset:  ../data/FB15k-237/semi_5/divided/
total entities:  5399 5399
		 len INTER df test:  22505
Done!!!


# Split WN18RR

In [8]:
WN18RR_15_path = '../data/WN18RR/semi/divided/'
split(WN18RR_15_path)

path to dataset:  ../data/WN18RR/semi/divided/
total entities:  11065 11065
		 len INTER df test:  3884
Done!!!


In [9]:
WN18RR_3_path = '../data/WN18RR/semi_3/divided/'
split(WN18RR_3_path)

path to dataset:  ../data/WN18RR/semi_3/divided/
total entities:  5608 5608
		 len INTER df test:  1218
Done!!!


In [10]:
WN18RR_5_path = '../data/WN18RR/semi_5/divided/'
split(WN18RR_5_path)

path to dataset:  ../data/WN18RR/semi_5/divided/
total entities:  6012 6012
		 len INTER df test:  1444
Done!!!


# Split DBbook2014

In [11]:
DBbook_15 = '../data/KG_datasets/dbbook2014/semi/divided/'
split(DBbook_15)

path to dataset:  ../data/KG_datasets/dbbook2014/semi/divided/
total entities:  5842 5842
		 len INTER df test:  26881
Done!!!


In [12]:
DBbook_3 = '../data/KG_datasets/dbbook2014/semi_3/divided/'
split(DBbook_3)

path to dataset:  ../data/KG_datasets/dbbook2014/semi_3/divided/
total entities:  5914 5914
		 len INTER df test:  33410
Done!!!


In [13]:
DBbook_5 = '../data/KG_datasets/dbbook2014/semi_5/divided/'
split(DBbook_5)

path to dataset:  ../data/KG_datasets/dbbook2014/semi_5/divided/
total entities:  5965 5965
		 len INTER df test:  29201
Done!!!


# Split ML1M

In [14]:
ml1m_15 = '../data/KG_datasets/ml1m/semi/divided/'
split(ml1m_15)

path to dataset:  ../data/KG_datasets/ml1m/semi/divided/
total entities:  5488 5488
		 len INTER df test:  36369
Done!!!


In [15]:
ml1m_3 = '../data/KG_datasets/ml1m/semi_3/divided/'
split(ml1m_3)

path to dataset:  ../data/KG_datasets/ml1m/semi_3/divided/
total entities:  5561 5561
		 len INTER df test:  33616
Done!!!


In [16]:
ml1m_5 = '../data/KG_datasets/ml1m/semi_5/divided/'
split(ml1m_5)

path to dataset:  ../data/KG_datasets/ml1m/semi_5/divided/
total entities:  5697 5697
		 len INTER df test:  38727
Done!!!


# Splitting for unoverlapped data

In [48]:
def unoverlapped_split(path):
    print('path to dataset: ', path)
    
    intra_df1 = pd.read_csv(path + 'train1.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    intra_df2 = pd.read_csv(path + 'train2.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
        
    
    intra_df1_train, intra_df1_test = split_df(intra_df1)
    intra_df2_train, intra_df2_test = split_df(intra_df2)
        
    
    # merge intra_test
    intra_df_test = pd.concat([intra_df1_test, intra_df2_test], axis=0)
    
    # saving
    save_path = path + 'fix_intra-test/'
    if os.path.exists(save_path) and os.path.exists(save_path):
        shutil.rmtree(save_path)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
        
    
    # save intra_test
    intra_df_test.to_csv(save_path + 'intra_test.csv', index=False)
    
    # save intra train
    intra_df1_train.to_csv(save_path + 'intra_train1.csv', index=False)
    intra_df2_train.to_csv(save_path + 'intra_train2.csv', index=False)        
    
    # inter-domain
    inter_df12 = pd.read_csv(path + 'cross_h1t2.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    inter_df21 = pd.read_csv(path + 'cross_h2t1.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    inter_df = pd.concat([inter_df12, inter_df21], axis=0)
    
    print('\t\t len INTER df test: ', len(inter_df))
    
    # save
    inter_df.to_csv(save_path + 'inter_test.csv', index=False)
    
    print('Done!!!')

In [None]:
def unoverlapped_split(path):
    print('path to dataset: ', path)
    
    intra_df1 = pd.read_csv(path + 'train1.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    ent2ix_1 = get_dictionaries(intra_df1, ent=True)
    rel2ix_1 = get_dictionaries(intra_df1, ent=False)
    n1 = len(ent2ix_1)
    ent_set_1 = set(ent2ix_1.keys())
     
    intra_df2 = pd.read_csv(path + 'train2.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    ent2ix_2 = get_dictionaries(intra_df2, ent=True)
    rel2ix_2 = get_dictionaries(intra_df2, ent=False)
    n2 = len(ent2ix_2)
    ent_set_2 = set(ent2ix_2.keys())
    
    #common_df = pd.read_csv(path + 'train_common.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    ent2ix_common = {}
    rel2ix_common = {}
    n_common = len(ent2ix_common)
    ent_set_common = set(ent2ix_common.keys())
    
    #cross_1_common = pd.read_csv(path + 'cross_1_common.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    #cross_2_common = pd.read_csv(path + 'cross_2_common.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    
    # merge ent2ix from all df
    ent2ix_1_common = {k: v + n1 for k, v in ent2ix_common.items()}
    ent2ix_1.update(ent2ix_1_common)
    ent2ix_2 = {k: v + n1 + n_common for k, v in ent2ix_2.items()}
    ent2ix_1.update(ent2ix_2)
    ent2ix = ent2ix_1

    print('total entities: ', n1 + n_common + n2, len(ent2ix))

    # merge rel2ix from all df
    dif_rel = set(rel2ix_common.keys()).difference(set(rel2ix_1.keys()))
    for rel in dif_rel:
        rel_id = len(rel2ix_1)
        rel2ix_1[rel] = rel_id
    dif_rel = set(rel2ix_2.keys()).difference(set(rel2ix_1.keys()))
    for rel in dif_rel:
        rel_id = len(rel2ix_1)
        rel2ix_1[rel] = rel_id

    rel2ix = rel2ix_1
    
    intra_df = pd.concat([intra_df1, intra_df2], axis=0)
    
    kg = utils.Extended_KnowledgeGraph(df=intra_df, ent2ix=ent2ix, rel2ix=rel2ix)
    
    kg, intra_kg_test = kg.split_kg(share=0.95)    
    
    # inter-domain
    inter_df_12 = pd.read_csv(path + 'cross_h1t2.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    inter_df_21 = pd.read_csv(path + 'cross_h2t1.csv', delimiter='\t').rename(columns={'head':'from', 'rel':'rel', 'tail':'to'})
    inter_df = pd.concat([inter_df_12, inter_df_21], axis=0)
    
    inter_kg = utils.Extended_KnowledgeGraph(df=inter_df, ent2ix=kg.ent2ix, rel2ix=kg.rel2ix)
    
    print('\t\t len INTER df test: ', len(inter_df))
    
    kg_dict = {'intra_train': kg,
               'intra_test': intra_kg_test,
               'inter_test': inter_kg,
               'n1': n1,
               'n2': n2,
               'n_common': n_common,
              }
    
    # saving
    save_path = path + 'fix_intra-test/'
    if os.path.exists(save_path) and os.path.exists(save_path):
        shutil.rmtree(save_path)
    if not os.path.exists(save_path):
        os.makedirs(save_path)    
        
    pickle.dump(kg_dict, open(save_path + 'kg_dict.pkl', 'bw'))
    
    print('Done!!!')

### split FB15k-237

In [8]:
FB15k_path = '../data/FB15k-237/divided/'
split(FB15k_path)

path to dataset:  ../data/FB15k-237/divided/


FileNotFoundError: [Errno 2] No such file or directory: '../data/FB15k-237/divided/train_common.csv'

In [7]:
ls ../

Done_Nickel_ICML2011-A Three-Way Model for Collective Learning on Multi-Relational Data_done.pdf
ICDM07-ASALSAN.pdf
Kemp-etal-AAAI06.pdf
README.md
[0m[01;35mRescal.jpg[0m
__init__.py
[01;34m__pycache__[0m/
[01;34mdata[0m/
[01;34mrescal_als[0m/
[01;34mrescal_torch[0m/
[01;34mresults[0m/
[01;34mwasserstein_ot[0m/


In [51]:
WN18RR_path = 'data/WN18RR/divided/'
unoverlapped_split(WN18RR_path)

path to dataset:  data/WN18RR/divided/
		 len INTER df test:  3889
Done!!!


In [52]:
DBbook_path = 'data/KG_datasets/dbbook2014/kg/divided/'
unoverlapped_split(DBbook_path)

path to dataset:  data/KG_datasets/dbbook2014/kg/divided/
		 len INTER df test:  32141
Done!!!


In [53]:
ML1M_path = 'data/KG_datasets/ml1m/kg/divided/'
unoverlapped_split(ML1M_path)

path to dataset:  data/KG_datasets/ml1m/kg/divided/
		 len INTER df test:  23745
Done!!!
