In [2]:
import os

import pandas as pd
import numpy  as np

In [3]:
from template_dataset import Dataset

In [4]:
from nltk.corpus   import wordnet
from sklearn.utils import shuffle

In [5]:
import copy
import tqdm
import pickle

In [6]:
# Location of the raw WN18RR splits. Set the WN18RR_DIR environment variable
# to point elsewhere instead of editing three hard-coded paths. Each file is
# tab-separated with no header row and three columns: head, relation, tail.
WN18RR_DIR = os.environ.get('WN18RR_DIR', '/users/iris/rsourty/datasets/wn18rr')

train, valid, test = (
    pd.read_csv(
        os.path.join(WN18RR_DIR, f'{split}.txt'),
        sep = '\t',
        header = None,
        names = ['head', 'relation', 'tail'],
    )
    for split in ('train', 'valid', 'test')
)

In [7]:
test.head(n = 5)  # first rows of the raw test split (synset offsets)

Unnamed: 0,head,relation,tail
0,6845599,_member_of_domain_usage,3754979
1,789448,_verb_group,1062739
2,8860123,_member_of_domain_region,5688486
3,2233096,_member_meronym,2233338
4,1371092,_hypernym,1352059


In [8]:
def retrieve_words(id_token, lookup = None, pos_order = ('n', 'a', 's', 'r', 'v')):
    """
    Retrieve the WordNet synset name for a synset offset.

    An offset does not say which part of speech it belongs to, so each POS in
    ``pos_order`` is tried in turn and the first synset found wins.
    ADJ, ADJ_SAT, ADV, NOUN, VERB = "a", "s", "r", "n", "v"

    Args:
        id_token: synset offset, as read from the WN18RR head/tail columns.
        lookup: callable ``(pos, offset) -> synset or None``. Defaults to
            ``wordnet.synset_from_pos_and_offset``; injectable for testing.
        pos_order: parts of speech to try, in order.

    Returns:
        The synset name (``synset.name()``), or None if no POS matches.
    """
    if lookup is None:
        lookup = wordnet.synset_from_pos_and_offset

    for pos in pos_order:
        try:
            synset = lookup(pos, id_token)
        except Exception:
            # A (pos, offset) pair that does not exist may raise, depending on
            # the NLTK version; treat it as "try the next POS". Catching
            # Exception (not a bare except) lets KeyboardInterrupt through.
            continue
        if synset is not None:
            # Use synset.name() rather than parsing str(synset): for names that
            # contain an apostrophe (e.g. o'clock) the repr holds an extra quote,
            # and the old split("'")[1] returned a truncated, colliding name.
            return synset.name()

    return None

In [9]:
# Replace every synset offset in the head and tail columns of each split by the
# name returned by retrieve_words.
for dataset in tqdm.tqdm([train, valid, test], position = 0):
    dataset['head'] = dataset['head'].map(retrieve_words)
    dataset['tail'] = dataset['tail'].map(retrieve_words)

100%|██████████| 3/3 [00:16<00:00,  7.84s/it]


In [10]:
def _drop_unresolved_triples(frame, columns = ('head', 'tail')):
    """
    Return a copy of ``frame`` without triples whose entity lookup failed.

    An entity counts as unresolved when it is missing (None/NaN; retrieve_words
    returns None when no part of speech matches), an empty string, or the
    literal string 'nan'. Every column in ``columns`` is checked the same way.
    """
    keep = pd.Series(True, index = frame.index)
    for column in columns:
        values = frame[column]
        keep &= values.notna() & (values != '') & (values != 'nan')
    # .copy() returns an independent frame, so the later column reassignments
    # on train/valid do not trigger SettingWithCopyWarning.
    return frame[keep].copy()


# Clean train and valid with the same criteria. Previously train and valid used
# different filters and neither dropped None, which merged every unresolved
# entity into a single token.
# NOTE(review): test is left untouched, as in the original pipeline (presumably
# to keep the benchmark test set complete). Confirm this is intended.
train = _drop_unresolved_triples(train)
valid = _drop_unresolved_triples(valid)

In [11]:
def get_token_idx(train, valid, test, columns = ('head', 'tail')):
    """
    Build a bidirectional entity vocabulary over the three splits.

    Entities get ids in order of first appearance. The scan covers train, valid,
    then test, and within each split every column of ``columns`` in turn.

    Returns:
        idx_to_token: dict[int, entity]
        token_to_idx: dict[entity, int]
    """
    pieces = [frame[column] for frame in (train, valid, test) for column in columns]
    if not pieces:
        return {}, {}

    # One concat of all columns, instead of growing a Series in a loop from an
    # untyped pd.Series() seed (dtype deprecation warnings, repeated copies).
    entities = pd.concat(pieces, ignore_index = True)
    vocabulary = entities.drop_duplicates().reset_index(drop = True)

    idx_to_token = vocabulary.to_dict()
    token_to_idx = {token: idx for idx, token in idx_to_token.items()}

    return idx_to_token, token_to_idx

In [12]:
def get_relation_idx(train, valid, test, column = 'relation'):
    """
    Build a bidirectional relation vocabulary over the three splits.

    Relations get ids in order of first appearance, scanning train, valid,
    then test.

    relation_to_idx: dict[original id, sub id]
    idx_to_relation: dict[sub id, original id]
    """
    # One concat instead of growing a Series in a loop from an untyped
    # pd.Series() seed (dtype deprecation warnings, repeated copies).
    relations = pd.concat(
        [frame[column] for frame in (train, valid, test)],
        ignore_index = True,
    )
    vocabulary = relations.drop_duplicates().reset_index(drop = True)

    idx_to_relation = vocabulary.to_dict()
    relation_to_idx = {relation: idx for idx, relation in idx_to_relation.items()}

    return idx_to_relation, relation_to_idx

In [13]:
# Entity vocabulary shared by every split (and by every dataset built below).
idx_to_token, token_to_idx = get_token_idx(train = train, valid = valid, test = test)

In [14]:
# Relation vocabulary shared by every split.
idx_to_relation, relation_to_idx = get_relation_idx(train = train, valid = valid, test = test)

In [15]:
# Replace relation labels and entity names by their integer ids. A value that is
# absent from a vocabulary raises KeyError, exactly like a plain dict lookup.
for dataset in [train, valid, test]:
    dataset['relation'] = dataset['relation'].map(relation_to_idx.__getitem__)
    dataset['head'] = dataset['head'].map(token_to_idx.__getitem__)
    dataset['tail'] = dataset['tail'].map(token_to_idx.__getitem__)

In [16]:
test.head(n = 5)  # first rows of the test split after integer encoding

Unnamed: 0,head,relation,tail
0,878,7,5374
1,11738,9,29937
2,405,8,16604
3,10317,4,33804
4,36120,0,31829


In [17]:
# Number of distinct entities and relations in the shared vocabularies.
n_entity, n_relation = len(idx_to_token), len(relation_to_idx)

In [18]:
# Seed NumPy's global RNG, then shuffle the training triples. shuffle() receives
# no explicit seed, so it draws from (and advances) that global RNG. The later
# train.sample(...) calls also pass no random_state and depend on it too.
np.random.seed(seed = 42)
train = shuffle(train, random_state = None)

In [19]:
# Partition the shuffled training triples into three disjoint, near-equal
# shards, one per teacher. np.array_split is applied to row positions rather than
# to the DataFrame: on a DataFrame it goes through the deprecated
# DataFrame.swapaxes (pandas >= 2.1). Shard sizes are unchanged: the first
# len(train) % 3 shards get one extra row. Original index labels are kept.
train_teacher_1, train_teacher_2, train_teacher_3 = (
    train.iloc[positions] for positions in np.array_split(np.arange(len(train)), 3)
)

In [20]:
def get_list_train_entities(train, head_column = 'head', tail_column = 'tail'):
    """
    List the distinct entities appearing as head or tail in ``train``.

    Args:
        train: triples DataFrame containing ``head_column`` and ``tail_column``.
        head_column: name of the head-entity column.
        tail_column: name of the tail-entity column.

    Returns:
        list of distinct entities, in Python set iteration order.
    """
    # Read the two columns by name instead of unpacking every row positionally.
    # The old `(w1, _, w2) in iterrows()` required exactly three columns in
    # (head, relation, tail) order, and zip is far cheaper than iterrows.
    # Insertion order (head, then tail, row by row) is the same as before.
    list_training_entities = set()
    for w1, w2 in zip(train[head_column], train[tail_column]):
        list_training_entities.add(w1)
        list_training_entities.add(w2)
    return list(list_training_entities)

In [21]:
# Distinct entities seen by each teacher's shard, then by the full training set.
list_entities_train_teacher_1, list_entities_train_teacher_2, list_entities_train_teacher_3 = (
    get_list_train_entities(shard)
    for shard in (train_teacher_1, train_teacher_2, train_teacher_3)
)
list_entities_train_teacher = get_list_train_entities(train)

In [21]:
# Teacher 1: its own training shard; valid/test and all id mappings are shared.
dataset_teacher_1 = Dataset(
    train=train_teacher_1,
    valid=valid,
    test=test,
    token_to_idx=token_to_idx,
    idx_to_token=idx_to_token,
    relation_to_idx=relation_to_idx,
    idx_to_relation=idx_to_relation,
    list_entities_train=list_entities_train_teacher_1,
)

In [22]:
# Teacher 2: its own training shard; valid/test and all id mappings are shared.
dataset_teacher_2 = Dataset(
    train=train_teacher_2,
    valid=valid,
    test=test,
    token_to_idx=token_to_idx,
    idx_to_token=idx_to_token,
    relation_to_idx=relation_to_idx,
    idx_to_relation=idx_to_relation,
    list_entities_train=list_entities_train_teacher_2,
)

In [23]:
# Teacher 3: its own training shard; valid/test and all id mappings are shared.
dataset_teacher_3 = Dataset(
    train=train_teacher_3,
    valid=valid,
    test=test,
    token_to_idx=token_to_idx,
    idx_to_token=idx_to_token,
    relation_to_idx=relation_to_idx,
    idx_to_relation=idx_to_relation,
    list_entities_train=list_entities_train_teacher_3,
)

In [24]:
# Persist teacher 1's dataset with the newest pickle protocol.
with open('/users/iris/rsourty/experiments/distillation/datasets/teacher_1.pickle', 'wb') as handle:
    pickle.dump(dataset_teacher_1, handle, pickle.HIGHEST_PROTOCOL)

In [25]:
# Persist teacher 2's dataset with the newest pickle protocol.
with open('/users/iris/rsourty/experiments/distillation/datasets/teacher_2.pickle', 'wb') as handle:
    pickle.dump(dataset_teacher_2, handle, pickle.HIGHEST_PROTOCOL)

In [26]:
# Persist teacher 3's dataset with the newest pickle protocol.
with open('/users/iris/rsourty/experiments/distillation/datasets/teacher_3.pickle', 'wb') as handle:
    pickle.dump(dataset_teacher_3, handle, pickle.HIGHEST_PROTOCOL)

In [27]:
# Full dataset: the whole (shuffled) training set with the shared valid/test
# splits and id mappings.
wordnet_18_rr = Dataset(
    train=train,
    valid=valid,
    test=test,
    token_to_idx=token_to_idx,
    idx_to_token=idx_to_token,
    relation_to_idx=relation_to_idx,
    idx_to_relation=idx_to_relation,
    list_entities_train=list_entities_train_teacher,
)

In [28]:
# Persist the full dataset with the newest pickle protocol.
with open('/users/iris/rsourty/experiments/distillation/datasets/100_wn_18_rr.pickle', 'wb') as handle:
    pickle.dump(wordnet_18_rr, handle, pickle.HIGHEST_PROTOCOL)

In [234]:
(len(test), len(test.columns))  # (number of test triples, number of columns)

(3134, 3)

In [22]:
# Build reduced training sets from random fractions of the training triples.
# NOTE: `percent` is a fraction (0.2 means 20%), and the file names embed it as
# is (e.g. wn18rr_0.2_percent.pickle). No random_state is passed to sample(),
# so the draw depends on NumPy's global RNG state.
for percent in tqdm.tqdm([0.20, 0.25, 0.33], position = 0):
    sample = train.sample(frac = percent)
    list_entities_sample = get_list_train_entities(sample)

    dataset = Dataset(
        train=sample,
        valid=valid,
        test=test,
        token_to_idx=token_to_idx,
        idx_to_token=idx_to_token,
        relation_to_idx=relation_to_idx,
        idx_to_relation=idx_to_relation,
        list_entities_train=list_entities_sample,
    )

    with open(f'/users/iris/rsourty/experiments/distillation/datasets/wn18rr_{percent}_percent.pickle', 'wb') as handle:
        pickle.dump(dataset, handle, pickle.HIGHEST_PROTOCOL)

100%|██████████| 3/3 [00:03<00:00,  1.09s/it]
