In [1]:
import os

# Pin the notebook to a single GPU. CUDA reads this variable only when it
# initialises, so it is set before torch or any torch-based package is
# imported. `setdefault` keeps a device selection already exported in the
# user's shell instead of silently overriding it with this machine-specific index.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '5')

import numpy as np
import torch
from torch import optim

from poem.evaluation import RankBasedEvaluator
from poem.instance_creation_factories.triples_factory import TriplesFactory
from poem.models.unimodal.conv_e import ConvE
from poem.training.cwa import CWATrainingLoop

# Seed the RNGs the pipeline may draw from (weight initialisation, dropout),
# so re-running the notebook gives comparable numbers.
# torch.manual_seed also seeds all CUDA devices.
# GPU kernels may still be non-deterministic.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

**Step 1: Create instances**

In [2]:
# Triples file (relative to this notebook) used to build the training instances.
path_to_training_data = "../tests/resources/test.txt"

In [3]:
# ConvE is meant to be trained together with inverse triples, so the factory
# is asked to create them while reading the triples file.
factory = TriplesFactory(
    create_inverse_triples=True,
    path=path_to_training_data,
)

**Step 2: Configure KGE model**

In [4]:
# ConvE hyper-parameters, passed to the model below as keyword arguments.
# ConvE (Dettmers et al., 2018) reshapes each embedding into `input_channels`
# 2D grids of embedding_height x embedding_width before convolving. The three
# values must therefore multiply to embedding_dim (here 1 * 10 * 20 == 200).
# NOTE(review): assumed to match the reshape in poem's ConvE — confirm in conv_e.py.
config = dict(
    embedding_dim       = 200,
    input_channels      = 1,
    output_channels     = 32,
    embedding_height    = 10,
    embedding_width     = 20,
    kernel_height       = 3,
    kernel_width        = 3,
    input_dropout       = 0.2,
    feature_map_dropout = 0.2,
    output_dropout      = 0.3,
    preferred_device    = 'gpu',
)

# Fail fast with a readable message if the coupled shape parameters are edited
# inconsistently, instead of hitting an opaque reshape error inside the model.
assert (
    config['input_channels'] * config['embedding_height'] * config['embedding_width']
    == config['embedding_dim']
), 'input_channels * embedding_height * embedding_width must equal embedding_dim'

In [5]:
# Instantiate ConvE on the triples factory with the hyper-parameters defined above.
kge_model = ConvE(triples_factory=factory, **config)

**Step 3: Configure training loop**

In [6]:
# Adam with its default hyper-parameters over the parameters returned by
# `get_grad_params` (presumably those requiring gradients).
optimizer = optim.Adam(params=kge_model.get_grad_params())

# Wire model and optimizer into the CWA training loop.
cwa_training_loop = CWATrainingLoop(optimizer=optimizer, model=kge_model)

**Step 4: Train**

In [7]:
# Five epochs in mini-batches of 256. `train` returns one loss value per epoch.
losses = cwa_training_loop.train(batch_size=256, num_epochs=5)

100%|██████████| 118142/118142 [00:00<00:00, 247214.15it/s]
Training epoch on cuda: 100%|██████████| 5/5 [00:52<00:00, 10.41s/it]


In [8]:
# Display the per-epoch training losses returned by `train` above.
losses

[0.08639761942492256,
 0.0033559632430117015,
 0.001854307092430502,
 0.0014784366678397835,
 0.001321700629834906]

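**Step 5: Evaluate the model** (on the training triples, so the metrics below reflect fit to the training data rather than generalization)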

In [9]:
from poem.evaluation import RankBasedEvaluator

In [10]:
# Rank-based evaluator for the trained model, with negative-triple filtering
# switched off (presumably "raw" rather than "filtered" ranks — confirm in poem).
evaluator = RankBasedEvaluator(
    kge_model,
    filter_neg_triples=False,
)

In [11]:
# Map the triples file to ID triples using the factory's existing mapping.
# NOTE(review): this reads the *training* file again. The metrics below
# therefore measure fit on the training data, not generalisation to held-out triples.
test_triples = factory.map_triples_to_id(
    path_to_triples=path_to_training_data,
)

In [12]:
# Rank every triple and aggregate the metrics (mean rank, MRR, hits@k).
# This is the cell's last expression, so the resulting MetricResults is displayed.
evaluator.evaluate(
    test_triples,
    batch_size=8192,
)

⚡️ Evaluating triples : 100%|██████████| 59.1k/59.1k [00:05<00:00, 9.20ktriple(s)/s]


MetricResults(mean_rank=3733.1487531953076, mean_reciprocal_rank=0.010182661542889066, adjusted_mean_rank=0.5495989322662354, adjusted_mean_reciprocal_rank=166.6231689453125, hits_at_k={1: 0.010182661542889066, 3: 0.021626517241963062, 5: 0.0346278207580708, 10: 0.05469689018300012})