In [1]:
import sys

# Put the parent directory first on the import path; presumably this is what
# makes the `torchkge` package importable from this notebook — confirm.
sys.path.insert(0, '..')

from torchkge.utils.datasets import load_wikidatasets

# Load the 'countries' dataset as three graphs (train / validation / test),
# passing ../data as the data home.
kg_train, kg_valid, kg_test = load_wikidatasets(
    which='countries',
    limit_=0,
    data_home='../data',
)

In [3]:
# Report how many distinct entities and relations the training graph holds.
print(
    f'# of entities : {kg_train.n_ent}',
    f'# of relations: {kg_train.n_rel}',
    sep='\n',
)

# of entities : 28777
# of relations: 166


In [4]:
# Report the number of facts (triplets) in each of the three splits.
print(
    f'# of triplets in train : {kg_train.n_facts}',
    f'# of triplets in valid : {kg_valid.n_facts}',
    f'# of triplets in test  : {kg_test.n_facts}',
    sep='\n',
)

# of triplets in triple : 9207
# of triplets in valid  : 1078
# of triplets in test   : 1076


In [3]:
# First five (entity label, index) pairs of the entity vocabulary.
# NOTE(review): `kg` is not assigned in any visible cell of this notebook —
# it may come from a removed cell; confirm this runs on a fresh kernel.
[*kg.ent2ix.items()][:5]

[('Belgium', 0),
 ('Portugal', 1),
 ("People's Republic of China", 2),
 ('Brazil', 3),
 ('Germany', 4)]

In [4]:
# First five (index, entity label) pairs of the reverse entity mapping.
[*kg.ix2ent.items()][:5]

[(0, 'Belgium'),
 (1, 'Portugal'),
 (2, "People's Republic of China"),
 (3, 'Brazil'),
 (4, 'Germany')]

In [5]:
# First five (relation label, index) pairs of the relation vocabulary.
[*kg.rel2ix.items()][:5]

[('participant of', 0),
 ("topic's main Wikimedia portal", 1),
 ('motto', 2),
 ('Wikimedia outline', 3),
 ('currency', 4)]

In [6]:
# First five (index, relation label) pairs of the reverse relation mapping.
[*kg.ix2rel.items()][:5]

[(0, 'participant of'),
 (1, "topic's main Wikimedia portal"),
 (2, 'motto'),
 (3, 'Wikimedia outline'),
 (4, 'currency')]

In [7]:
# First five facts as (head index, relation index, tail index) triples.
# Each index sequence is sliced *before* zipping so that only five elements
# are iterated, instead of materialising a tuple for every fact in the graph
# and then discarding all but five of them. The displayed result is the same.
list(zip(kg.head_idx[:5], kg.relations[:5], kg.tail_idx[:5]))

[(tensor(0), tensor(11), tensor(0)),
 (tensor(0), tensor(16), tensor(3047)),
 (tensor(0), tensor(16), tensor(424)),
 (tensor(0), tensor(16), tensor(4)),
 (tensor(0), tensor(16), tensor(1379))]

In [8]:
from random import randint
test_size = 20

# Print `test_size` randomly drawn facts in readable
# "head -> relation -> tail" form by mapping the stored indices back to labels.
for _ in range(test_size):
    # randint() includes BOTH end points, so the upper bound has to be
    # n_facts - 1; with n_facts itself the draw could land one position past
    # the last stored fact and fail when indexing (n_facts is taken to be the
    # number of stored triples).
    i = randint(0, kg.n_facts - 1)
    head = kg.ix2ent[kg.head_idx[i].item()]
    relation = kg.ix2rel[kg.relations[i].item()]
    tail = kg.ix2ent[kg.tail_idx[i].item()]
    print(head, '->', relation, '->', tail)

Israel -> diplomatic relation -> United Arab Emirates
Armenia -> diplomatic relation -> Syria
Israel -> diplomatic relation -> Mexico
Taiwan -> diplomatic relation -> Australia
Northern Song Dynasty -> part of -> Song dynasty
Guyana -> diplomatic relation -> Barbados
Equatorial Guinea -> diplomatic relation -> Brazil
Bolivia -> diplomatic relation -> Peru
Haiti -> diplomatic relation -> Chile
Iraq -> diplomatic relation -> Denmark
Sierra Leone -> diplomatic relation -> Liberia
Taiwan -> diplomatic relation -> Venezuela
Kingdom of France (1791-1792) -> follows -> Kingdom of France
Italy -> diplomatic relation -> Kosovo
Romania -> diplomatic relation -> Russia
Ukraine -> shares border with -> Czechoslovakia
Bishopric of Verdun -> country -> Holy Roman Empire
Portugal -> diplomatic relation -> Bangladesh
Ethiopia -> diplomatic relation -> Hungary
Serbia -> shares border with -> Kosovo


In [9]:
# Split the full graph into train / validation / test parts.
# NOTE(review): this rebinds kg_train / kg_valid / kg_test, which the first
# cell already assigned from load_wikidatasets, and relies on a `kg` object
# that no visible cell defines — confirm which of the two sources is intended.
kg_train, kg_valid, kg_test = kg.split_kg(
    share=0.8,
    validation=True,
)

In [13]:
from os.path import exists
import pickle

# Pickle the training graph to disk once. If the file is already there it is
# left untouched, so an older pickle is never overwritten by this cell.
train_pickle = '../data/WikiDataSets/countries/WikiData_train.pkl'
if not exists(train_pickle):
    with open(train_pickle, mode='wb') as pickle_file:
        pickle.dump(kg_train, pickle_file)

In [10]:
import torch
from torch import cuda
from torch.optim import Adam
from tqdm.autonotebook import tqdm

from torchkge.models import TransEModel, DistMultModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss, DataLoader

In [11]:
# Report the number of facts (triples) in each split used for training.
print(
    f'# of triples in train: {kg_train.n_facts}',
    f'# of triples in valid: {kg_valid.n_facts}',
    f'# of triples in test : {kg_test.n_facts}',
    sep='\n',
)

# of triples in train: 9207
# of triples in valid: 1074
# of triples in test : 1080


In [12]:
# training config
# Hyper-parameters consumed by the model / optimizer / loader cells below.

ent_emb_dim = 50     # embedding size handed to the model constructor
lr = 4e-4            # learning rate given to the Adam optimizer
b_size = 500         # batch size shared by the train and validation loaders
margin = 0.5         # margin passed to MarginLoss
summary_step = 200   # NOTE(review): not referenced by the visible training loop, which logs every 50 epochs

In [13]:
# define the model and criterion
#
# DistMult is the active model; the TransE variant is kept here only as a
# commented-out alternative:
#   TransEModel(ent_emb_dim, kg_train.n_ent, kg_train.n_rel, dissimilarity_type='L2')
model = DistMultModel(
    ent_emb_dim,
    kg_train.n_ent,
    kg_train.n_rel,
)

# Margin-based loss comparing scores of true facts and corrupted facts.
criterion = MarginLoss(margin)

In [14]:
# use cuda if it is available
#
# When a CUDA device is present, move both the model and the loss onto it;
# otherwise nothing happens and both stay where they were created.
if cuda.is_available():
    cuda.empty_cache()  # release unoccupied cached GPU memory first
    for gpu_module in (model, criterion):
        gpu_module.cuda()

In [15]:
# define the optimizer and dataloader, sampler

# Adam over all model parameters, with a small L2 weight decay.
optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

# Negative sampler built from the training graph; it is used to corrupt
# batches for both the training and the validation pass.
sampler = BernoulliNegativeSampler(kg_train)

# Only ask the loaders to place the data on the GPU when CUDA is actually
# available, mirroring the guard used when the model is moved to the GPU.
# Requesting use_cuda='all' unconditionally breaks on CPU-only machines, where
# the model stays on the CPU. `None` is the loader's "no CUDA" setting per the
# torchkge DataLoader documentation — confirm against the local torchkge copy.
loader_cuda = 'all' if cuda.is_available() else None
tr_dl = DataLoader(kg_train, batch_size=b_size, use_cuda=loader_cuda)
val_dl = DataLoader(kg_valid, batch_size=b_size, use_cuda=loader_cuda)

In [16]:
# training
#
# Every epoch performs, in this order:
#   1. one optimisation pass over the training batches,
#   2. one gradient-free pass over the validation batches,
#   3. a progress line every 50 epochs,
#   4. model.normalize_parameters(),
#   5. a checkpoint whenever the mean validation loss improves.

import os

n_epochs = 2000

# NOTE(review): the file is named "best_transe" although the model cell
# instantiates DistMultModel — confirm the intended checkpoint name before
# renaming, since other code may load this exact path.
checkpoint_path = '../experiment/wiki_country/best_transe.tar'
# torch.save does not create missing folders; without this the first
# checkpoint write fails only after a full epoch has already been computed.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Start from +inf so that the first epoch always counts as an improvement,
# whatever the scale of the loss.
best_val_loss = float('inf')

iterator = tqdm(range(n_epochs), unit='epoch')
for epoch in iterator:

    # ---- training pass -------------------------------------------------
    tr_loss = 0
    model.train()

    # Pre-set the batch index so that an empty loader yields "zero batches"
    # instead of a NameError or a stale index left over from an earlier loop.
    step = -1
    for step, batch in enumerate(tr_dl):
        h, t, r = batch[0], batch[1], batch[2]
        n_h, n_t = sampler.corrupt_batch(h, t, r)  # negative head, negative tail

        optimizer.zero_grad()

        pos, neg = model(h, t, n_h, n_t, r)
        loss = criterion(pos, neg)
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
    # Mean loss per batch; max() guards the division when no batch was seen.
    tr_loss /= max(step + 1, 1)

    # ---- validation pass -----------------------------------------------
    model.eval()
    val_loss = 0
    step = -1
    with torch.no_grad():  # no gradients are needed while evaluating
        for step, batch in enumerate(val_dl):
            h, t, r = batch[0], batch[1], batch[2]
            # Negatives are re-drawn at every pass, so the validation loss is
            # itself a noisy estimate.
            n_h, n_t = sampler.corrupt_batch(h, t, r)
            pos, neg = model(h, t, n_h, n_t, r)
            val_loss += criterion(pos, neg).item()
    val_loss /= max(step + 1, 1)

    if epoch % 50 == 0:
        tqdm.write('Epoch {} | mean loss: {:.5f}, valid loss: {:.5f}'.format(epoch+1, tr_loss, val_loss))

    # NOTE(review): normalisation runs after validation but before the
    # checkpoint is written, so the saved parameters are the normalised ones
    # while val_loss was measured just before normalisation — confirm that
    # this ordering is intended.
    model.normalize_parameters()

    # ---- checkpoint on improvement -------------------------------------
    is_best = val_loss < best_val_loss
    if is_best:
        state = {'epoch': epoch,
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict()}
        torch.save(state, checkpoint_path)
        best_val_loss = val_loss

  0%|          | 0/2000 [00:00<?, ?epoch/s]

Epoch 1 | mean loss: 242.02944, valid loss: 178.71754
Epoch 51 | mean loss: 89.85932, valid loss: 73.36523
Epoch 101 | mean loss: 22.11345, valid loss: 33.34391
Epoch 151 | mean loss: 8.63050, valid loss: 22.15195
Epoch 201 | mean loss: 4.05967, valid loss: 18.34913
Epoch 251 | mean loss: 3.05279, valid loss: 15.67769
Epoch 301 | mean loss: 2.23112, valid loss: 12.65838
Epoch 351 | mean loss: 1.55090, valid loss: 13.33125
Epoch 401 | mean loss: 1.64024, valid loss: 11.20280
Epoch 451 | mean loss: 1.19760, valid loss: 12.98193
Epoch 501 | mean loss: 1.10246, valid loss: 12.46544
Epoch 551 | mean loss: 0.95250, valid loss: 10.99642
Epoch 601 | mean loss: 0.97751, valid loss: 10.96828
Epoch 651 | mean loss: 0.93941, valid loss: 11.31976
Epoch 701 | mean loss: 0.88975, valid loss: 10.87393
Epoch 751 | mean loss: 0.66451, valid loss: 10.40798
Epoch 801 | mean loss: 0.66126, valid loss: 9.88012
Epoch 851 | mean loss: 0.89190, valid loss: 10.26491
Epoch 901 | mean loss: 0.72292, valid loss: 1