In [2]:
from poem.instance_creation_factories.triples_numeric_literals_factory import TriplesNumericLiteralsFactory
from poem.kge_models.kge_models_using_numerical_literals.complex_literal_cwa import ComplexLiteralCWA
from poem.preprocessing.triples_preprocessing_utils.basic_triple_utils import create_entity_and_relation_mappings, \
    load_triples
from torch import optim
import numpy as np
from poem.training_loops.cwa_training_loop import CWATrainingLoop

**Step 1: Create instances**

In [3]:
# Input files for this walkthrough, given relative to the notebook's directory:
# the numeric-literal triples and the triples used for training.
path_to_literals = '../tests/resources/numerical_literals.txt'
path_to_training_data = '../tests/resources/test.txt'

In [4]:
# Read both input files with the same loader: first the training triples,
# then the numeric literals (same call order as two separate statements).
training_triples, literals = (
    load_triples(path=source_path)
    for source_path in (path_to_training_data, path_to_literals)
)

In [5]:
# Derive the entity and relation mappings from the training triples only;
# both mappings are reused below by the factory and to size the model.
(entity_to_id, relation_to_id) = create_entity_and_relation_mappings(
    triples=training_triples,
)

In [6]:
# The factory receives the two mappings together with the numeric literals,
# and is then asked to turn the training triples into CWA instances.
factory = TriplesNumericLiteralsFactory(
    entity_to_id=entity_to_id,
    relation_to_id=relation_to_id,
    numeric_triples=literals,
)
instances = factory.create_cwa_instances(triples=training_triples)

**Step 2: Configure KGE model**

In [7]:
# Tunable hyperparameters, named so they are easy to find and change.
embedding_dim = 200
input_dropout = 0.2

# Entity/relation counts come from the mappings built earlier; the multimodal
# data is taken from the instances produced by the factory.
kge_model = ComplexLiteralCWA(
    embedding_dim=embedding_dim,
    num_entities=len(entity_to_id),
    num_relations=len(relation_to_id),
    input_dropout=input_dropout,
    multimodal_data=instances.multimodal_data,
)

**Step 3: Configure training loop**

In [8]:
# Collect only the parameters that require gradients. A concrete list is used
# instead of `filter(...)`: a filter object is a one-shot iterator, so once the
# optimizer has consumed it `parameters` would be left exhausted in the kernel,
# and re-using it in another cell would pass an empty parameter list.
parameters = [p for p in kge_model.parameters() if p.requires_grad]

# Adam is created with its default settings (no learning rate is passed here).
optimizer = optim.Adam(params=parameters)

# The training loop is given the model and the optimizer that updates it.
cwa_training_loop = CWATrainingLoop(kge_model=kge_model, optimizer=optimizer)

**Step 4: Train KGE model**

In [9]:
# Train for two epochs with a batch size of 128; the call hands back a pair
# that is unpacked into the fitted model and the recorded losses.
fitted_kge_model, losses = cwa_training_loop.train(
    training_instances=instances,
    num_epochs=2,
    batch_size=128,
)

Training epoch: 100%|██████████| 2/2 [06:25<00:00, 189.36s/it]


In [10]:
# Show the losses returned by the training call (the cell's last expression is displayed).
losses

[4.002053277573112, 2.0248831789942154]