In [8]:
from poem.instance_creation_factories.triples_factory import TriplesFactory
from poem.models.unimodal.ermlp import ERMLP
from torch import optim
import numpy as np
from poem.training.owa import OWATrainingLoop
import os
import logging
from poem.evaluation import RankBasedEvaluator
import sys
import timeit
from torch import nn
# Route log records from every library (root logger) to stderr at DEBUG level.
# The original cell created a StreamHandler but never attached it, so nothing
# was actually routed through it. We attach it here, reusing an existing plain
# stderr StreamHandler if one is already present so that re-running this cell
# does not duplicate every log line.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
handler = next(
    (
        existing
        for existing in logger.handlers
        # exact type check: skip FileHandler and other StreamHandler subclasses
        if type(existing) is logging.StreamHandler and existing.stream is sys.stderr
    ),
    None,
)
if handler is None:
    handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(handler)

**Step 1: Create training instances from the triples file**

In [2]:
# Relative path: resolved against the kernel's current working directory.
path_to_training_data = "../tests/resources/test.txt"

In [3]:
# Build a TriplesFactory from the training file; it is handed to the model
# and is also used to map triples to ids for evaluation.
factory = TriplesFactory(
    path=path_to_training_data,
)

**Step 2: Configure KGE model**

In [4]:
# ER-MLP model built on the training factory, optimized with a margin ranking
# loss (margin 1.0, averaged over the batch).
# random_seed is fixed for repeatability (assumes ERMLP uses it to seed its
# initialization — confirm in the model implementation).
kge_model = ERMLP(
    triples_factory=factory,
    embedding_dim=50,
    criterion=nn.MarginRankingLoss(margin=1.0, reduction='mean'),
    random_seed=2,
)

**Step 3: Configure Training Loop**

In [5]:
# Adam over the parameters returned by get_grad_params(); lr=1e-3 is PyTorch's
# default, written out so the learning rate is visible and easy to tune.
optimizer = optim.Adam(params=kge_model.get_grad_params(), lr=1e-3)

# OWA training loop (presumably open-world assumption) binding model and optimizer.
owa_training_loop = OWATrainingLoop(
    model=kge_model,
    optimizer=optimizer,
)

**Step 4: Train**

In [6]:
# 5 epochs over the training triples in mini-batches of 256; `losses` presumably
# holds the per-epoch loss values echoed in the progress output below.
losses = owa_training_loop.train(num_epochs=5, batch_size=256)

Training epoch on cpu:  20%|██        | 1/5 [00:02<00:09,  2.38s/it]

Losses: [0.9411756521545025]


Training epoch on cpu:  40%|████      | 2/5 [00:04<00:07,  2.36s/it]

Losses: [0.9411756521545025, 0.8541601078817151]


Training epoch on cpu:  60%|██████    | 3/5 [00:07<00:04,  2.35s/it]

Losses: [0.9411756521545025, 0.8541601078817151, 0.794050917429715]


Training epoch on cpu:  80%|████████  | 4/5 [00:09<00:02,  2.41s/it]

Losses: [0.9411756521545025, 0.8541601078817151, 0.794050917429715, 0.740356265172695]


Training epoch on cpu: 100%|██████████| 5/5 [00:12<00:00,  2.53s/it]

Losses: [0.9411756521545025, 0.8541601078817151, 0.794050917429715, 0.740356265172695, 0.6980177377377447]





In [9]:
# Rank-based evaluator for the trained model. filter_neg_triples=True presumably
# drops corrupted triples that are actually known to be true ("filtered" ranking)
# — NOTE(review): confirm against the RankBasedEvaluator implementation.
evaluator = RankBasedEvaluator(
    kge_model,
    filter_neg_triples=True,
)

In [10]:
# NOTE(review): despite its name, `test_triples` is built from the *training*
# file, so the metrics computed below measure fit on seen data rather than
# generalization. Point this at a held-out split for a meaningful evaluation.
test_triples = factory.map_triples_to_id(
    path_to_triples=path_to_training_data,
)

In [11]:
# Filtering requires holding all known triples in memory, so a moderate
# evaluation batch size is used to keep peak memory in check.
evaluator.evaluate(test_triples, batch_size=1024)

⚡️ Evaluating triples : 100%|██████████| 59.1k/59.1k [05:23<00:00, 186triple(s)/s]


MetricResults(mean_rank=4094.606752890589, mean_reciprocal_rank=0.017495894770699665, adjusted_mean_rank=0.6028128862380981, adjusted_mean_reciprocal_rank=227.02249145507812, hits_at_k={1: 0.017495894770699665, 3: 0.037082493947960926, 5: 0.047519087200148975, 10: 0.06264495268405816})