diff --git a/src/pykeen/triples/triples_factory.py b/src/pykeen/triples/triples_factory.py index 07b6c709a2..ea81183297 100644 --- a/src/pykeen/triples/triples_factory.py +++ b/src/pykeen/triples/triples_factory.py @@ -7,7 +7,7 @@ import logging import os import re -from typing import Callable, Collection, List, Mapping, Optional, Sequence, Set, TextIO, Union +from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Sequence, Set, TextIO, Union import numpy as np import pandas as pd @@ -164,6 +164,9 @@ class TriplesFactory: #: Whether to create inverse triples create_inverse_triples: bool = False + #: Arbitrary metadata to go with the graph + metadata: Optional[Dict[str, Any]] = None + # The following fields get generated automatically #: The inverse mapping for entity_label_to_id; initialized automatically @@ -190,6 +193,9 @@ def __post_init__(self): self.entity_id_to_label = invert_mapping(mapping=self.entity_to_id) self.relation_id_to_label = invert_mapping(mapping=self.relation_to_id) + if self.metadata is None: + self.metadata = {} + # vectorized versions self._vectorized_entity_mapper = np.vectorize(self.entity_to_id.get) self._vectorized_relation_mapper = np.vectorize(self.relation_to_id.get) @@ -205,6 +211,7 @@ def from_labeled_triples( relation_to_id: Optional[RelationMapping] = None, compact_id: bool = True, filter_out_candidate_inverse_relations: bool = True, + metadata: Optional[Dict[str, Any]] = None, ) -> 'TriplesFactory': """ Create a new triples factory from label-based triples. @@ -221,6 +228,8 @@ def from_labeled_triples( Whether to compact IDs such that the IDs are consecutive. :param filter_out_candidate_inverse_relations: Whether to remove triples with relations with the inverse suffix. + :param metadata: + Arbitrary key/value pairs to store as metadata :return: A new triples factory. 
@@ -270,6 +279,7 @@ def from_labeled_triples( relation_to_id=relation_to_id, mapped_triples=mapped_triples, create_inverse_triples=create_inverse_triples, + metadata=metadata, ) @classmethod @@ -280,6 +290,7 @@ def from_path( entity_to_id: Optional[EntityMapping] = None, relation_to_id: Optional[RelationMapping] = None, compact_id: bool = True, + metadata: Optional[Dict[str, Any]] = None, ) -> 'TriplesFactory': """ Create a new triples factory from triples stored in a file. @@ -294,6 +305,10 @@ def from_path( The mapping from relations labels to ID. If None, create a new one from the triples. :param compact_id: Whether to compact IDs such that the IDs are consecutive. + :param metadata: + Arbitrary key/value pairs to store as metadata with the triples factory. Do not + include ``path`` as a key because it is automatically taken from the ``path`` + kwarg to this function. :return: A new triples factory. @@ -314,11 +329,17 @@ def from_path( entity_to_id=entity_to_id, relation_to_id=relation_to_id, compact_id=compact_id, + metadata={ + 'path': path, + **(metadata or {}), + }, ) def clone_and_exchange_triples( self, mapped_triples: MappedTriples, + extra_metadata: Optional[Dict[str, Any]] = None, + keep_metadata: bool = True, ) -> "TriplesFactory": """ Create a new triples factory sharing everything except the triples. @@ -328,6 +349,11 @@ def clone_and_exchange_triples( :param mapped_triples: The new mapped triples. + :param extra_metadata: + Extra metadata to include in the new triples factory. If ``keep_metadata`` is true, + the dictionaries will be unioned with precedence taken on keys from ``extra_metadata``. + :param keep_metadata: + Pass the current factory's metadata to the new triples factory :return: The new factory. 
@@ -337,6 +363,10 @@ def clone_and_exchange_triples( relation_to_id=self.relation_to_id, mapped_triples=mapped_triples, create_inverse_triples=self.create_inverse_triples, + metadata={ + **(self.metadata if keep_metadata else {}), + **(extra_metadata or {}), + }, ) @property @@ -369,11 +399,16 @@ def triples(self) -> np.ndarray: # noqa: D401 def extra_repr(self) -> str: """Extra representation string.""" - return ( - f"num_entities={self.num_entities}, " - f"num_relations={self.num_relations}, " - f"num_triples={self.num_triples}, " - f"inverse_triples={self.create_inverse_triples}" + d = [ + ('num_entities', self.num_entities), + ('num_relations', self.num_relations), + ('num_triples', self.num_triples), + ('inverse_triples', self.create_inverse_triples), + ] + d.extend(sorted(self.metadata.items())) + return ', '.join( + f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}' + for k, v in d ) def __repr__(self): # noqa: D105 @@ -683,14 +718,17 @@ def new_with_restriction( """ keep_mask = None + extra_metadata = {} # Filter for entities if entities is not None: + extra_metadata['entity_restriction'] = entities keep_mask = self.get_mask_for_entities(entities=entities, invert=invert_entity_selection) remaining_entities = self.num_entities - len(entities) if invert_entity_selection else len(entities) logger.info(f"keeping {format_relative_comparison(remaining_entities, self.num_entities)} entities.") # Filter for relations if relations is not None: + extra_metadata['relation_restriction'] = relations relation_mask = self.get_mask_for_relations(relations=relations, invert=invert_relation_selection) remaining_relations = self.num_relations - len(relations) if invert_entity_selection else len(relations) logger.info(f"keeping {format_relative_comparison(remaining_relations, self.num_relations)} relations.") @@ -702,4 +740,7 @@ def new_with_restriction( num_triples = keep_mask.sum() logger.info(f"keeping {format_relative_comparison(num_triples, self.num_triples)} triples.") - 
return self.clone_and_exchange_triples(mapped_triples=self.mapped_triples[keep_mask]) + return self.clone_and_exchange_triples( + mapped_triples=self.mapped_triples[keep_mask], + extra_metadata=extra_metadata, + ) diff --git a/tests/test_triples_factory.py b/tests/test_triples_factory.py index c26fc7b980..0b0831b7c2 100644 --- a/tests/test_triples_factory.py +++ b/tests/test_triples_factory.py @@ -11,6 +11,7 @@ import torch from pykeen.datasets import Nations +from pykeen.datasets.nations import NATIONS_TRAIN_PATH from pykeen.triples import LCWAInstances, TriplesFactory, TriplesNumericLiteralsFactory from pykeen.triples.generation import generate_triples from pykeen.triples.splitting import ( @@ -432,6 +433,52 @@ def test_inverse_triples(self): msg='Wrong number of relations in factory', ) + def test_metadata(self): + """Test metadata passing for triples factories.""" + t = Nations().training + self.assertEqual(NATIONS_TRAIN_PATH, t.metadata['path']) + self.assertEqual( + ( + f'TriplesFactory(num_entities=14, num_relations=55, num_triples=1592,' + f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}")' + ), + repr(t), + ) + + entities = ['poland', 'ussr'] + x = t.new_with_restriction(entities=entities) + self.assertEqual(NATIONS_TRAIN_PATH, x.metadata['path']) + self.assertEqual( + ( + f'TriplesFactory(num_entities=14, num_relations=55, num_triples=37,' + f' inverse_triples=False, entity_restriction={repr(entities)}, path="{NATIONS_TRAIN_PATH}")' + ), + repr(x), + ) + + relations = ['negativebehavior'] + v = t.new_with_restriction(relations=relations) + self.assertEqual(NATIONS_TRAIN_PATH, v.metadata['path']) + self.assertEqual( + ( + f'TriplesFactory(num_entities=14, num_relations=55, num_triples=29,' + f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}", relation_restriction={repr(relations)})' + ), + repr(v), + ) + + w = t.clone_and_exchange_triples(t.triples[0:5], keep_metadata=False) + self.assertIsInstance(w, TriplesFactory) + self.assertNotIn('path', 
w.metadata) + self.assertEqual( + 'TriplesFactory(num_entities=14, num_relations=55, num_triples=5, inverse_triples=False)', + repr(w), + ) + + y, z = t.split() + self.assertEqual(NATIONS_TRAIN_PATH, y.metadata['path']) + self.assertEqual(NATIONS_TRAIN_PATH, z.metadata['path']) + def test_get_absolute_split_sizes(): """Test get_absolute_split_sizes."""