In [1]:
import sys
sys.path.insert(0,'..')
from torchkge.utils.datasets import load_wikidatasets

# Load the 'countries' WikiDataSets graph from the local raw-data folder.
# NOTE(review): limit_=0 presumably disables the rare-entity filter — confirm
# against load_wikidatasets.
kg = load_wikidatasets(
    which='countries',
    limit_=0,
    data_home='../raw_data',
)

# With validation=True the graph is split three ways: train / valid / test.
kg_train, kg_valid, kg_test = kg.split_kg(validation=True)

In [2]:
# Report the vocabulary sizes exposed by the training split.
print(
    f'# of entities : {kg_train.n_ent}',
    f'# of relations: {kg_train.n_rel}',
    sep='\n',
)

# of entities : 28777
# of relations: 166


In [3]:
# Report how many facts (triplets) ended up in each split.
print(
    f'# of triplets in train : {kg_train.n_facts}',
    f'# of triplets in valid : {kg_valid.n_facts}',
    f'# of triplets in test  : {kg_test.n_facts}',
    sep='\n',
)

# of triplets in train : 9227
# of triplets in valid : 1060
# of triplets in test  : 1074


In [4]:
# Peek at the first five (entity label, index) pairs of the entity vocabulary.
[*kg_train.ent2ix.items()][:5]

[('Belgium', 0),
 ('Portugal', 1),
 ("People's Republic of China", 2),
 ('Brazil', 3),
 ('Germany', 4)]

In [5]:
# Peek at the first five (index, entity label) pairs of the reverse entity mapping.
[*kg_train.ix2ent.items()][:5]

[(0, 'Belgium'),
 (1, 'Portugal'),
 (2, "People's Republic of China"),
 (3, 'Brazil'),
 (4, 'Germany')]

In [6]:
# Peek at the first five (relation label, index) pairs of the relation vocabulary.
[*kg_train.rel2ix.items()][:5]

[('participant of', 0),
 ("topic's main Wikimedia portal", 1),
 ('motto', 2),
 ('Wikimedia outline', 3),
 ('currency', 4)]

In [7]:
# Peek at the first five (index, relation label) pairs of the reverse relation mapping.
[*kg_train.ix2rel.items()][:5]

[(0, 'participant of'),
 (1, "topic's main Wikimedia portal"),
 (2, 'motto'),
 (3, 'Wikimedia outline'),
 (4, 'currency')]

In [8]:
# Show the first five training facts as (head index, relation index, tail index).
# Slice each index tensor BEFORE zipping so only five elements are iterated,
# instead of materialising a tuple for every fact and then discarding most of them.
list(zip(kg_train.head_idx[:5], kg_train.relations[:5], kg_train.tail_idx[:5]))

[(tensor(0), tensor(16), tensor(3047)),
 (tensor(0), tensor(16), tensor(424)),
 (tensor(0), tensor(16), tensor(4)),
 (tensor(0), tensor(16), tensor(1379)),
 (tensor(0), tensor(16), tensor(1380))]

In [9]:
from random import randint
test_size = 20

# Print `test_size` randomly chosen training facts as readable
# "head -> relation -> tail" lines to sanity-check the index <-> label mappings.
for _ in range(test_size):
    # randint is inclusive on BOTH ends, so the upper bound has to be
    # n_facts - 1; passing n_facts itself can produce an index one past the
    # last fact and raise IndexError.
    i = randint(0, kg_train.n_facts - 1)
    head = kg_train.ix2ent[kg_train.head_idx[i].item()]
    relation = kg_train.ix2rel[kg_train.relations[i].item()]
    tail = kg_train.ix2ent[kg_train.tail_idx[i].item()]
    print(head, '->', relation, '->', tail)

Algeria -> shares border with -> Morocco
Brazil -> diplomatic relation -> France
Comoros -> shares border with -> Mozambique
Pakistan -> diplomatic relation -> Malta
Tunisia -> diplomatic relation -> United States of America
Kingdom of Aragon -> replaces -> County of Aragon
Saudi Arabia -> diplomatic relation -> Russia
Almohad Caliphate -> replaced by -> Kingdom of Portugal
Montenegro -> diplomatic relation -> Kosovo
United States of America -> diplomatic relation -> Indonesia
South Russia -> replaces -> Mountainous Republic of the Northern Caucasus
Switzerland -> shares border with -> Germany
League  Federal -> country -> Argentina
Indonesia -> diplomatic relation -> People's Republic of China
People's Republic of China -> diplomatic relation -> India
Assam -> shares border with -> Manipur
Liechtenstein -> shares border with -> Nazi Germany
Ottoman Empire -> shares border with -> Austria-Hungary
Kingdom of Holland -> followed by -> First French Empire
Australia -> diplomatic relation 

In [10]:
# kg_train, kg_valid, kg_test = kg.split_kg(share=0.8, validation=True)

In [11]:
# from os.path import exists
# import pickle

# if not exists('../data/WikiDataSets/countries/WikiData_train.pkl'):
#     with open('../data/WikiDataSets/countries/WikiData_train.pkl', mode='wb') as io:
#         pickle.dump(kg_train, io)

In [12]:
import torch
from torch import cuda
from torch.optim import Adam
from tqdm.autonotebook import tqdm

from torchkge.models import TransEModel, DistMultModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss, DataLoader

In [13]:
# Re-state the split sizes right before training.
print(
    f'# of triples in train: {kg_train.n_facts}',
    f'# of triples in valid: {kg_valid.n_facts}',
    f'# of triples in test : {kg_test.n_facts}',
    sep='\n',
)

# of triples in train: 9227
# of triples in valid: 1060
# of triples in test : 1074


In [14]:
# training config
# Every tunable hyper-parameter lives in this cell so a run can be changed in one place.

ent_emb_dim = 20    # embedding dimension handed to the model constructor
lr = 4e-4           # learning rate for the Adam optimizer
b_size = 500        # batch size for the train and validation data loaders
margin = 0.5        # margin handed to MarginLoss
summary_step = 50   # interval (in epochs) intended for progress summaries

In [15]:
# define the model and criterion

# Alternative scorer, kept so the two models can be swapped quickly:
# model = TransEModel(ent_emb_dim, kg_train.n_ent, kg_train.n_rel, dissimilarity_type='L2')
model = DistMultModel(
    ent_emb_dim,
    kg_train.n_ent,
    kg_train.n_rel,
)

# Margin-based loss over the positive / negative scores returned by the model.
criterion = MarginLoss(margin)

In [16]:
# use cuda if it is available

# Move the model and the loss onto the GPU; on a CPU-only machine this cell does nothing.
if cuda.is_available():
    cuda.empty_cache()
    for gpu_bound in (model, criterion):
        gpu_bound.cuda()

In [17]:
# define the optimizer and dataloader, sampler

# Adam over all model parameters, with a small weight decay.
optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

# Sampler used in the training loop to corrupt heads / tails into negatives.
sampler = BernoulliNegativeSampler(kg_train)

# Only ask the loaders to place data on the GPU when one exists. The model is
# moved to cuda conditionally, so hard-coding use_cuda='all' would make the
# notebook inconsistent (and likely crash) on a CPU-only machine.
# NOTE(review): None is the documented TorchKGE value for "no cuda" — confirm
# against the local torchkge copy on sys.path.
dl_cuda_mode = 'all' if cuda.is_available() else None
tr_dl = DataLoader(kg_train, batch_size=b_size, use_cuda=dl_cuda_mode)
val_dl = DataLoader(kg_valid, batch_size=b_size, use_cuda=dl_cuda_mode)

In [18]:
# training
# Each epoch: one optimisation pass over the training loader, one gradient-free
# pass over the validation loader, a periodic progress line, then parameter
# normalisation and best-validation-loss bookkeeping.

n_epochs = 100
best_val_loss = 1e+10  # sentinel; lowered whenever validation loss improves
iterator = tqdm(range(n_epochs), unit='epoch')
for epoch in iterator:

    # ---- training pass ----
    tr_loss = 0
    model.train()

    for step, batch in enumerate(tr_dl):
        h, t, r = batch[0], batch[1], batch[2]
        n_h, n_t = sampler.corrupt_batch(h, t, r)  # negative head, negative tail

        optimizer.zero_grad()

        pos, neg = model(h, t, n_h, n_t, r)
        loss = criterion(pos, neg)
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
    tr_loss /= (step + 1)  # mean of the per-batch losses

    # ---- validation pass (no gradients needed) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for step, batch in enumerate(val_dl):
            h, t, r = batch[0], batch[1], batch[2]
            n_h, n_t = sampler.corrupt_batch(h, t, r)
            pos, neg = model(h, t, n_h, n_t, r)
            val_loss += criterion(pos, neg).item()
    val_loss /= (step + 1)

    # Log every `summary_step` epochs (from the config cell) rather than a
    # hard-coded interval, so the setting actually takes effect.
    if epoch % summary_step == 0:
        tqdm.write('Epoch {} | mean loss: {:.5f}, valid loss: {:.5f}'.format(epoch+1, tr_loss, val_loss))
    model.normalize_parameters()

    # Track the best validation loss so `is_best` is meaningful; previously the
    # sentinel was never updated and every epoch counted as "best".
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        # Checkpointing is disabled; uncomment to save the best model so far.
        # NOTE(review): the file name says "transe" while the active model is DistMult.
#         state = {'epoch': epoch,
#                  'state_dict': model.state_dict(),
#                  'optimizer': optimizer.state_dict()}
#         torch.save(state, '../experiment/wiki_country/best_transe.tar')

  0%|          | 0/100 [00:00<?, ?epoch/s]

Epoch 1 | mean loss: 242.79393, valid loss: 176.97351
Epoch 51 | mean loss: 180.95449, valid loss: 137.00548


In [19]:
from torchkge.evaluation import LinkPredictionEvaluator

In [20]:
# Build the link-prediction evaluator for the model on the held-out test split.
evaluator = LinkPredictionEvaluator(model, kg_test)

In [21]:
# Run link-prediction evaluation in batches of 32 and keep the returned metrics
# (displayed in a later cell).
result = evaluator.evaluate(b_size=32)

Link prediction evaluation:   0%|          | 0/34 [00:00<?, ?batch/s]

In [22]:
# Print the evaluator's summary: Hit@10, Mean Rank and MRR, each raw and filtered.
evaluator.print_results()

Hit@10 : 0.1099 		 Filt. Hit@10 : 0.1988
Mean Rank : 2467 	 Filt. Mean Rank : 2446
MRR : 0.0724 		 Filt. MRR : 0.1006


In [23]:
# Display the metrics returned by evaluate(): a dict with 'Normal' and 'Filtered' entries.
result

{'Normal': {'Hit@10': 0.1099, 'Mean Rank': 2467, 'MRR': 0.0724},
 'Filtered': {'Hit@10': 0.1988, 'Mean Rank': 2446, 'MRR': 0.1006}}

In [24]:
from torchkge.evaluation import TripletClassificationEvaluator

In [27]:
# Triplet-classification evaluator built from the model plus the validation and test splits.
# NOTE(review): presumably validation is used to fit decision thresholds and test to
# score them — confirm against TripletClassificationEvaluator.
tc_evaluator = TripletClassificationEvaluator(model, kg_valid, kg_test)

In [28]:
# Run the evaluator's evaluate step in batches of 128.
# NOTE(review): presumably this fits the thresholds that accuracy() relies on — TODO confirm.
tc_evaluator.evaluate(b_size = 128)

In [29]:
# Classification accuracy, computed in batches of 128.
# NOTE(review): if positives and negatives are balanced, the saved value (~0.53) is
# close to chance — worth investigating before relying on this model.
tc_evaluator.accuracy(b_size=128)

0.5283985102420856