Skip to content

Commit

Permalink
Add generalized metadata storage to TriplesFactory and improve repr (#…
Browse files Browse the repository at this point in the history
…211)

Closes #142
  • Loading branch information
cthoyt committed Dec 11, 2020
1 parent f25c011 commit c5423d8
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 7 deletions.
55 changes: 48 additions & 7 deletions src/pykeen/triples/triples_factory.py
Expand Up @@ -7,7 +7,7 @@
import logging
import os
import re
from typing import Callable, Collection, List, Mapping, Optional, Sequence, Set, TextIO, Union
from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Sequence, Set, TextIO, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -164,6 +164,9 @@ class TriplesFactory:
#: Whether to create inverse triples
create_inverse_triples: bool = False

#: Arbitrary metadata to go with the graph
metadata: Optional[Dict[str, Any]] = None

# The following fields get generated automatically

#: The inverse mapping for entity_label_to_id; initialized automatically
Expand All @@ -190,6 +193,9 @@ def __post_init__(self):
self.entity_id_to_label = invert_mapping(mapping=self.entity_to_id)
self.relation_id_to_label = invert_mapping(mapping=self.relation_to_id)

if self.metadata is None:
self.metadata = {}

# vectorized versions
self._vectorized_entity_mapper = np.vectorize(self.entity_to_id.get)
self._vectorized_relation_mapper = np.vectorize(self.relation_to_id.get)
Expand All @@ -205,6 +211,7 @@ def from_labeled_triples(
relation_to_id: Optional[RelationMapping] = None,
compact_id: bool = True,
filter_out_candidate_inverse_relations: bool = True,
metadata: Optional[Dict[str, Any]] = None,
) -> 'TriplesFactory':
"""
Create a new triples factory from label-based triples.
Expand All @@ -221,6 +228,8 @@ def from_labeled_triples(
Whether to compact IDs such that the IDs are consecutive.
:param filter_out_candidate_inverse_relations:
Whether to remove triples with relations with the inverse suffix.
:param metadata:
Arbitrary key/value pairs to store as metadata
:return:
A new triples factory.
Expand Down Expand Up @@ -270,6 +279,7 @@ def from_labeled_triples(
relation_to_id=relation_to_id,
mapped_triples=mapped_triples,
create_inverse_triples=create_inverse_triples,
metadata=metadata,
)

@classmethod
Expand All @@ -280,6 +290,7 @@ def from_path(
entity_to_id: Optional[EntityMapping] = None,
relation_to_id: Optional[RelationMapping] = None,
compact_id: bool = True,
metadata: Optional[Dict[str, Any]] = None,
) -> 'TriplesFactory':
"""
Create a new triples factory from triples stored in a file.
Expand All @@ -294,6 +305,10 @@ def from_path(
The mapping from relations labels to ID. If None, create a new one from the triples.
:param compact_id:
Whether to compact IDs such that the IDs are consecutive.
:param metadata:
Arbitrary key/value pairs to store as metadata with the triples factory. Do not
include ``path`` as a key because it is automatically taken from the ``path``
kwarg to this function.
:return:
A new triples factory.
Expand All @@ -314,11 +329,17 @@ def from_path(
entity_to_id=entity_to_id,
relation_to_id=relation_to_id,
compact_id=compact_id,
metadata={
'path': path,
**(metadata or {}),
},
)

def clone_and_exchange_triples(
self,
mapped_triples: MappedTriples,
extra_metadata: Optional[Dict[str, Any]] = None,
keep_metadata: bool = True,
) -> "TriplesFactory":
"""
Create a new triples factory sharing everything except the triples.
Expand All @@ -328,6 +349,11 @@ def clone_and_exchange_triples(
:param mapped_triples:
The new mapped triples.
:param extra_metadata:
Extra metadata to include in the new triples factory. If ``keep_metadata`` is true,
the dictionaries will be unioned with precedence taken on keys from ``extra_metadata``.
:param keep_metadata:
Pass the current factory's metadata to the new triples factory
:return:
The new factory.
Expand All @@ -337,6 +363,10 @@ def clone_and_exchange_triples(
relation_to_id=self.relation_to_id,
mapped_triples=mapped_triples,
create_inverse_triples=self.create_inverse_triples,
metadata={
**(extra_metadata or {}),
**(self.metadata if keep_metadata else {}),
},
)

@property
Expand Down Expand Up @@ -369,11 +399,16 @@ def triples(self) -> np.ndarray: # noqa: D401

def extra_repr(self) -> str:
"""Extra representation string."""
return (
f"num_entities={self.num_entities}, "
f"num_relations={self.num_relations}, "
f"num_triples={self.num_triples}, "
f"inverse_triples={self.create_inverse_triples}"
d = [
('num_entities', self.num_entities),
('num_relations', self.num_relations),
('num_triples', self.num_triples),
('inverse_triples', self.create_inverse_triples),
]
d.extend(sorted(self.metadata.items()))
return ', '.join(
f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}'
for k, v in d
)

def __repr__(self): # noqa: D105
Expand Down Expand Up @@ -683,14 +718,17 @@ def new_with_restriction(
"""
keep_mask = None

extra_metadata = {}
# Filter for entities
if entities is not None:
extra_metadata['entity_restriction'] = entities
keep_mask = self.get_mask_for_entities(entities=entities, invert=invert_entity_selection)
remaining_entities = self.num_entities - len(entities) if invert_entity_selection else len(entities)
logger.info(f"keeping {format_relative_comparison(remaining_entities, self.num_entities)} entities.")

# Filter for relations
if relations is not None:
extra_metadata['relation_restriction'] = relations
relation_mask = self.get_mask_for_relations(relations=relations, invert=invert_relation_selection)
remaining_relations = self.num_relations - len(relations) if invert_entity_selection else len(relations)
logger.info(f"keeping {format_relative_comparison(remaining_relations, self.num_relations)} relations.")
Expand All @@ -702,4 +740,7 @@ def new_with_restriction(

num_triples = keep_mask.sum()
logger.info(f"keeping {format_relative_comparison(num_triples, self.num_triples)} triples.")
return self.clone_and_exchange_triples(mapped_triples=self.mapped_triples[keep_mask])
return self.clone_and_exchange_triples(
mapped_triples=self.mapped_triples[keep_mask],
extra_metadata=extra_metadata,
)
47 changes: 47 additions & 0 deletions tests/test_triples_factory.py
Expand Up @@ -11,6 +11,7 @@
import torch

from pykeen.datasets import Nations
from pykeen.datasets.nations import NATIONS_TRAIN_PATH
from pykeen.triples import LCWAInstances, TriplesFactory, TriplesNumericLiteralsFactory
from pykeen.triples.generation import generate_triples
from pykeen.triples.splitting import (
Expand Down Expand Up @@ -432,6 +433,52 @@ def test_inverse_triples(self):
msg='Wrong number of relations in factory',
)

def test_metadata(self):
"""Test metadata passing for triples factories."""
t = Nations().training
self.assertEqual(NATIONS_TRAIN_PATH, t.metadata['path'])
self.assertEqual(
(
f'TriplesFactory(num_entities=14, num_relations=55, num_triples=1592,'
f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}")'
),
repr(t),
)

entities = ['poland', 'ussr']
x = t.new_with_restriction(entities=entities)
self.assertEqual(NATIONS_TRAIN_PATH, x.metadata['path'])
self.assertEqual(
(
f'TriplesFactory(num_entities=14, num_relations=55, num_triples=37,'
f' inverse_triples=False, entity_restriction={repr(entities)}, path="{NATIONS_TRAIN_PATH}")'
),
repr(x),
)

relations = ['negativebehavior']
v = t.new_with_restriction(relations=relations)
self.assertEqual(NATIONS_TRAIN_PATH, x.metadata['path'])
self.assertEqual(
(
f'TriplesFactory(num_entities=14, num_relations=55, num_triples=29,'
f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}", relation_restriction={repr(relations)})'
),
repr(v),
)

w = t.clone_and_exchange_triples(t.triples[0:5], keep_metadata=False)
self.assertIsInstance(w, TriplesFactory)
self.assertNotIn('path', w.metadata)
self.assertEqual(
'TriplesFactory(num_entities=14, num_relations=55, num_triples=5, inverse_triples=False)',
repr(w),
)

y, z = t.split()
self.assertEqual(NATIONS_TRAIN_PATH, y.metadata['path'])
self.assertEqual(NATIONS_TRAIN_PATH, z.metadata['path'])


def test_get_absolute_split_sizes():
"""Test get_absolute_split_sizes."""
Expand Down

0 comments on commit c5423d8

Please sign in to comment.