In [2]:
import random

import numpy as np
import torch
from torch import optim

from poem.constants import EMBEDDING_DIM, NUM_ENTITIES, NUM_RELATIONS, INPUT_DROPOUT
from poem.instance_creation_factories.triples_factory import TriplesFactory
from poem.kge_models.unimodal_kge_models.complex_cwa import ComplexCWA
from poem.preprocessing.triples_preprocessing_utils.basic_triple_utils import (
    create_entity_and_relation_mappings,
    load_triples,
)
from poem.training_loops.cwa_training_loop import CWATrainingLoop

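**Step 1: Create training instances** — load the triples, map entities and relations to IDs, and build the CWA (closed-world assumption) training instances.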

In [12]:
# Notebook configuration: input data location and the global random seed.
# The path is relative to the notebook's working directory.
path_to_training_data = '../tests/resources/test.txt'

# Seed every RNG the pipeline may draw from. torch presumably drives parameter
# initialisation and dropout. numpy/random may drive batch shuffling (TODO:
# confirm which one CWATrainingLoop uses). Seeding here, before the model is
# built, makes Restart & Run All produce the same losses on every run (CPU).
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

In [14]:
# Read the raw triples and derive the entity -> ID and relation -> ID mappings
# from them. Then let the factory turn the triples into CWA training instances.
training_triples = load_triples(path=path_to_training_data)
entity_to_id, relation_to_id = create_entity_and_relation_mappings(
    triples=training_triples,
)
factory = TriplesFactory(
    entity_to_id=entity_to_id,
    relation_to_id=relation_to_id,
)
instances = factory.create_cwa_instances(triples=training_triples)

**Step 2: Configure the KGE model**

In [18]:
# Size the model from the mappings built in Step 1, then instantiate the CWA
# variant of ComplEx. The embedding size (200) and input dropout rate (0.2)
# are this notebook's chosen hyper-parameters.
num_relations = len(relation_to_id)
num_entities = len(entity_to_id)
kge_model = ComplexCWA(
    num_entities=num_entities,
    num_relations=num_relations,
    embedding_dim=200,
    input_dropout=0.2,
)

**Step 3: Configure the training loop**

In [19]:
# Hand only the parameters that require gradients to the optimizer, and
# materialise them as a list. A bare `filter` object is a one-shot iterator:
# once Adam consumed it, the notebook variable `parameters` would be silently
# empty, and re-passing it would fail with "empty parameter list".
parameters = [p for p in kge_model.parameters() if p.requires_grad]
# Adam with PyTorch's default hyper-parameters (lr=1e-3).
optimizer = optim.Adam(params=parameters)
# Every entity ID, in the iteration order of `entity_to_id`. The training loop
# receives these as `all_entities`; presumably they are the candidate set for
# CWA scoring (TODO confirm in CWATrainingLoop).
all_entities = np.array(list(entity_to_id.values()))
cwa_training_loop = CWATrainingLoop(kge_model=kge_model, optimizer=optimizer, all_entities=all_entities)

**Step 4: Train the model**

In [20]:
# Run five epochs over the CWA instances in mini-batches of 256. `train`
# returns the fitted model together with one loss value per epoch.
fitted_kge_model, losses = cwa_training_loop.train(
    training_instances=instances,
    num_epochs=5,
    batch_size=256,
)

Training epoch: 100%|██████████| 5/5 [03:34<00:00, 42.68s/it]


In [21]:
# Display the loss values returned by `train` (one entry per epoch).
losses

[0.7229566037160764,
 0.6996402254101819,
 0.3058661299764912,
 0.07312749471104384,
 0.03328783938829509]