/
triples_factory.py
1431 lines (1226 loc) · 53.5 KB
/
triples_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
"""Implementation of basic instance factory which creates just instances based on standard KG triples."""
import dataclasses
import logging
import pathlib
import re
import warnings
from abc import ABC, abstractmethod
from typing import (
    Any,
    Callable,
    ClassVar,
    Collection,
    Dict,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    Set,
    TextIO,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from .instances import BatchedSLCWAInstances, LCWAInstances, SubGraphSLCWAInstances
from .splitting import split
from .utils import TRIPLES_DF_COLUMNS, load_triples, tensor_to_df
from ..constants import COLUMN_LABELS
from ..typing import EntityMapping, LabeledTriples, MappedTriples, RelationMapping, TorchRandomHint
from ..utils import (
    ExtraReprMixin,
    compact_mapping,
    format_relative_comparison,
    get_edge_index,
    invert_mapping,
    normalize_path,
    triple_tensor_to_set,
)
# The public names exported by ``from <this module> import *``.
__all__ = [
    "KGInfo",
    "CoreTriplesFactory",
    "TriplesFactory",
    "create_entity_mapping",
    "create_relation_mapping",
    "INVERSE_SUFFIX",
    "cat_triples",
    "splits_steps",
    "splits_similarity",
    "RelationInverter",
    "relation_inverter",
    "AnyTriples",
    "get_mapped_triples",
]

# Module-level logger.
logger = logging.getLogger(__name__)

# Label suffix that marks inverse relations. It is used when ordering relation labels
# (create_relation_mapping) and when detecting/removing candidate inverse relations
# in label-based triples (TriplesFactory.from_labeled_triples).
INVERSE_SUFFIX = "_inverse"
def create_entity_mapping(triples: LabeledTriples) -> EntityMapping:
    """Assign consecutive integer IDs to all entity labels occurring in the given triples.

    :param triples: shape: (n, 3), dtype: str
        The label-based triples; column 0 holds heads, column 2 holds tails.
    :returns:
        A dictionary mapping each entity label (as ``str``) to its ID. IDs follow the sorted
        label order, starting at 0.
    """
    # Entities are the union of all heads and all tails; relations (column 1) are ignored.
    unique_entities = set(triples[:, 0]) | set(triples[:, 2])
    # Sorting makes the ID assignment independent of the order in which triples are given.
    return {str(entity): entity_id for entity_id, entity in enumerate(sorted(unique_entities))}
def create_relation_mapping(relations: Iterable[str]) -> RelationMapping:
    """Assign consecutive integer IDs to the given relation labels.

    Labels are ordered by their name with the inverse suffix removed, and a label carrying the
    suffix is placed directly after the label without it.

    :param relations: The relation labels; duplicates are ignored.
    :returns:
        A dictionary mapping each relation label (as ``str``) to its ID, starting at 0.
    """
    suffix_at_end = f"{INVERSE_SUFFIX}$"

    def _sort_key(label: str) -> Tuple[str, bool]:
        # primary key: base name without the suffix; secondary key: False (plain) sorts before True (suffixed)
        return re.sub(suffix_at_end, "", label), label.endswith(f"{INVERSE_SUFFIX}")

    ordered_labels = sorted(set(relations), key=_sort_key)
    return {str(label): position for position, label in enumerate(ordered_labels)}
def _map_triples_elements_to_ids(
    triples: LabeledTriples,
    entity_to_id: EntityMapping,
    relation_to_id: RelationMapping,
) -> MappedTriples:
    """Translate label-based triples into ID-based triples using pre-defined mappings.

    Triples that contain a head, relation, or tail label missing from the respective mapping
    are dropped, with warnings logged. The result is de-duplicated.

    :param triples: shape: (n, 3), dtype: str
        The label-based triples.
    :param entity_to_id: The mapping used for the head and tail columns.
    :param relation_to_id: The mapping used for the relation column.
    :returns: shape: (m, 3), dtype: long
        The unique ID-based triples. ``np.unique`` is used, so the row order is not preserved.
    """
    if triples.size == 0:
        logger.warning("Provided empty triples to map.")
        return torch.empty(0, 3, dtype=torch.long)

    # Unknown labels are translated to the sentinel -1 via dict.get's default argument.
    lookup_entity = np.vectorize(entity_to_id.get)
    lookup_relation = np.vectorize(relation_to_id.get)
    heads = lookup_entity(triples[:, 0:1], [-1])
    relations = lookup_relation(triples[:, 1:2], [-1])
    tails = lookup_entity(triples[:, 2:3], [-1])

    missing_head = heads < 0
    missing_relation = relations < 0
    missing_tail = tails < 0
    head_misses, relation_misses, tail_misses = missing_head.sum(), missing_relation.sum(), missing_tail.sum()

    if max(head_misses, relation_misses, tail_misses) > 0:
        logger.warning(
            f"You're trying to map triples with {head_misses + tail_misses} entities and {relation_misses} relations"
            f" that are not in the training set. These triples will be excluded from the mapping.",
        )
        unmappable = missing_head | missing_relation | missing_tail
        keep = ~unmappable
        # boolean indexing flattens the (n, 1) columns; the trailing None restores the column axis
        heads, relations, tails = heads[keep, None], relations[keep, None], tails[keep, None]
        logger.warning(
            f"In total {unmappable.sum():.0f} from {triples.shape[0]:.0f} triples were filtered out",
        )

    id_triples = np.concatenate([heads, relations, tails], axis=1).astype(np.int64)
    # np.unique both reorders the rows and removes duplicate triples (implicit balancing of training samples)
    return torch.tensor(np.unique(id_triples, axis=0), dtype=torch.long)
def _get_triple_mask(
    ids: Collection[int],
    triples: MappedTriples,
    columns: Union[int, Collection[int]],
    invert: bool = False,
    max_id: Optional[int] = None,
) -> torch.BoolTensor:
    """Compute a boolean row mask that tests the selected column(s) for membership in ``ids``.

    :param ids: The IDs to test against.
    :param triples: The ID-based triples, one triple per row.
    :param columns:
        A single column index, or a collection of column indices. With more than one column, a row
        is selected only if *all* of its selected entries pass the (possibly inverted) test.
    :param invert: Whether to test for non-membership instead of membership.
    :param max_id: Accepted for interface compatibility; not used by this implementation.
    :returns:
        A boolean mask with one entry per row for an int column or several columns. A single-element
        collection of columns keeps its column axis.
    """
    selected = triples[:, columns]
    id_tensor = torch.as_tensor(list(ids), dtype=torch.long)
    mask = torch.isin(elements=selected, test_elements=id_tensor, assume_unique=False, invert=invert)
    column_count = 1 if isinstance(columns, int) else len(columns)
    # reduce over the selected columns: every one of them has to satisfy the condition
    return mask.all(dim=-1) if column_count > 1 else mask
def _ensure_ids(
labels_or_ids: Union[Collection[int], Collection[str]],
label_to_id: Mapping[str, int],
) -> Collection[int]:
"""Convert labels to IDs."""
return [label_to_id[l_or_i] if isinstance(l_or_i, str) else l_or_i for l_or_i in labels_or_ids]
# Type of a relation identifier: either a single int or a tensor of IDs.
RelationID = TypeVar("RelationID", int, torch.LongTensor)


class RelationInverter(ABC):
    """An interface for inverse-relation ID mapping.

    Subclasses implement :meth:`get_inverse_id`, :meth:`_map`, and :meth:`invert_`. The class derives
    from :class:`abc.ABC` so that ``@abstractmethod`` is actually enforced: the interface itself,
    or an incomplete subclass, cannot be instantiated.
    """

    # TODO: method is_inverse?

    @abstractmethod
    def get_inverse_id(self, relation_id: RelationID) -> RelationID:
        """Get the inverse ID for a given relation."""
        # TODO: inverse of inverse?
        raise NotImplementedError

    @abstractmethod
    def _map(self, batch: torch.LongTensor, index: int = 1) -> torch.LongTensor:
        """Map the relation IDs stored in column ``index`` of ``batch``, without inverting them."""
        raise NotImplementedError

    @abstractmethod
    def invert_(self, batch: torch.LongTensor, index: int = 1) -> torch.LongTensor:
        """Invert relations in a batch (in-place)."""
        raise NotImplementedError

    def map(self, batch: torch.LongTensor, index: int = 1, invert: bool = False) -> torch.LongTensor:
        """Map relations of batch, optionally also inverting them.

        :param batch: A batch whose column ``index`` holds relation IDs.
        :param index: The column holding relation IDs.
        :param invert: Whether to additionally apply :meth:`invert_` to the mapped batch.
        :returns: The result of :meth:`_map`, passed through :meth:`invert_` if ``invert`` is set.
        """
        batch = self._map(batch=batch, index=index)
        return self.invert_(batch=batch, index=index) if invert else batch
class DefaultRelationInverter(RelationInverter):
    """Interleaved relation IDs: relation ``r`` is mapped to ``2 * r``, and its inverse to ``2 * r + 1``."""

    # docstr-coverage: inherited
    def get_inverse_id(self, relation_id: RelationID) -> RelationID:  # noqa: D102
        # the inverse of an (even) mapped ID is the odd ID directly after it
        return relation_id + 1

    # docstr-coverage: inherited
    def _map(self, batch: torch.LongTensor, index: int = 1, invert: bool = False) -> torch.LongTensor:  # noqa: D102
        # operate on a copy so the caller's batch stays untouched; ``invert`` is accepted but not used here
        doubled = batch.clone()
        doubled[:, index] = doubled[:, index] * 2
        return doubled

    # docstr-coverage: inherited
    def invert_(self, batch: torch.LongTensor, index: int = 1) -> torch.LongTensor:  # noqa: D102
        # In-place: the number of relations stored in the triples factory includes the inverse relations,
        # and the inverse of relation ID ``i`` is ``i + 1``.
        batch[:, index] = batch[:, index] + 1
        return batch


relation_inverter = DefaultRelationInverter()
@dataclasses.dataclass
class Labeling:
    """A mapping between labels and IDs."""

    #: The mapping from labels to IDs.
    label_to_id: Mapping[str, int]

    #: The inverse mapping for label_to_id; initialized automatically
    id_to_label: Mapping[int, str] = dataclasses.field(init=False)

    #: A vectorized version of label_to_id.get (output dtype int); initialized automatically
    _vectorized_mapper: Callable[..., np.ndarray] = dataclasses.field(init=False, compare=False)

    #: A vectorized version of id_to_label.get (output dtype str); initialized automatically
    _vectorized_labeler: Callable[..., np.ndarray] = dataclasses.field(init=False, compare=False)

    def __post_init__(self):
        """Precompute inverse mappings."""
        self.id_to_label = invert_mapping(mapping=self.label_to_id)
        self._vectorized_mapper = np.vectorize(self.label_to_id.get, otypes=[int])
        self._vectorized_labeler = np.vectorize(self.id_to_label.get, otypes=[str])

    def label(
        self,
        ids: Union[int, Sequence[int], np.ndarray, torch.LongTensor],
        unknown_label: str = "unknown",
    ) -> np.ndarray:
        """Convert IDs to labels.

        :param ids: A single ID, a sequence of IDs, a numpy array, or a tensor of IDs.
        :param unknown_label: The label returned for IDs that are not part of the mapping.
        :returns: A string array of labels; a single int is treated as a one-element sequence.
        """
        # Normalize input
        if isinstance(ids, torch.Tensor):
            ids = ids.cpu().numpy()
        if isinstance(ids, int):
            ids = [ids]
        ids = np.asanyarray(ids)
        # label; the default argument of dict.get is broadcast to every element
        return self._vectorized_labeler(ids, (unknown_label,))

    @property
    def max_id(self) -> int:
        """Return the maximum ID (excl.), i.e., one more than the largest assigned ID; 0 for an empty mapping."""
        # default=-1 avoids ValueError from max() on an empty mapping
        return max(self.label_to_id.values(), default=-1) + 1

    def all_labels(self) -> np.ndarray:
        """Get all labels, in order."""
        return self.label(range(self.max_id))
def restrict_triples(
    mapped_triples: MappedTriples,
    entities: Optional[Collection[int]] = None,
    relations: Optional[Collection[int]] = None,
    invert_entity_selection: bool = False,
    invert_relation_selection: bool = False,
) -> MappedTriples:
    """Select a subset of triples based on the given entity and relation ID selection.

    :param mapped_triples:
        The ID-based triples.
    :param entities:
        The entity IDs of interest. If None, defaults to all entities.
    :param relations:
        The relation IDs of interest. If None, defaults to all relations.
    :param invert_entity_selection:
        Whether to invert the entity selection, i.e. select those triples without the provided entities.
    :param invert_relation_selection:
        Whether to invert the relation selection, i.e. select those triples without the provided relations.
    :return:
        A tensor of triples containing the entities and relations of interest. If neither ``entities`` nor
        ``relations`` is given, the input tensor itself is returned.
    """
    masks = []
    if entities is not None:
        # both the head (column 0) and the tail (column 2) need to fulfil the requirement
        masks.append(
            _get_triple_mask(
                ids=entities,
                triples=mapped_triples,
                columns=(0, 2),
                invert=invert_entity_selection,
            )
        )
    if relations is not None:
        masks.append(
            _get_triple_mask(
                ids=relations,
                triples=mapped_triples,
                columns=1,
                invert=invert_relation_selection,
            )
        )

    if not masks:
        # no filter requested
        return mapped_triples

    keep_mask = masks[0]
    for extra_mask in masks[1:]:
        keep_mask = keep_mask & extra_mask
    return mapped_triples[keep_mask]
class KGInfo(ExtraReprMixin):
    """Basic information about a knowledge graph: entity/relation counts and the inverse-triples flag."""

    #: the number of unique entities
    num_entities: int

    #: the number of relations (maybe including "artificial" inverse relations)
    num_relations: int

    #: whether to create inverse triples
    create_inverse_triples: bool

    #: the number of real relations, i.e., without artificial inverses
    real_num_relations: int

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        create_inverse_triples: bool,
    ) -> None:
        """
        Store the graph statistics.

        :param num_entities:
            the number of entities.
        :param num_relations:
            the number of relations, excluding artificial inverse relations.
        :param create_inverse_triples:
            whether to create inverse triples; if so, ``self.num_relations`` is twice ``num_relations``.
        """
        self.num_entities = num_entities
        self.real_num_relations = num_relations
        # each real relation gets an artificial inverse counterpart, which doubles the total count
        self.num_relations = 2 * num_relations if create_inverse_triples else num_relations
        self.create_inverse_triples = create_inverse_triples

    def iter_extra_repr(self) -> Iterable[str]:
        """Yield the parent's extra_repr components, followed by the counts and the inverse-triples flag."""
        yield from super().iter_extra_repr()
        yield f"num_entities={self.num_entities}"
        yield f"num_relations={self.num_relations}"
        yield f"create_inverse_triples={self.create_inverse_triples}"
class CoreTriplesFactory(KGInfo):
    """Create instances from ID-based triples."""

    #: the file name of the ID-based triples within a binary storage directory
    triples_file_name: ClassVar[str] = "numeric_triples.tsv.gz"
    #: the file name of the remaining state (counts, flags, metadata) within a binary storage directory
    base_file_name: ClassVar[str] = "base.pth"

    def __init__(
        self,
        mapped_triples: Union[MappedTriples, np.ndarray],
        num_entities: int,
        num_relations: int,
        create_inverse_triples: bool = False,
        metadata: Optional[Mapping[str, Any]] = None,
    ):
        """
        Create the triples factory.

        :param mapped_triples: shape: (n, 3)
            A three-column matrix where each row are the head identifier, relation identifier, then tail identifier.
        :param num_entities:
            The number of entities.
        :param num_relations:
            The number of relations.
        :param create_inverse_triples:
            Whether to create inverse triples.
        :param metadata:
            Arbitrary metadata to go with the graph

        :raises TypeError:
            if the mapped_triples are of non-integer dtype
        :raises ValueError:
            if the mapped_triples are of invalid shape
        """
        super().__init__(
            num_entities=num_entities,
            num_relations=num_relations,
            create_inverse_triples=create_inverse_triples,
        )
        # ensure torch.Tensor
        mapped_triples = torch.as_tensor(mapped_triples)
        # input validation
        if mapped_triples.ndim != 2 or mapped_triples.shape[1] != 3:
            raise ValueError(f"Invalid shape for mapped_triples: {mapped_triples.shape}; must be (n, 3)")
        if mapped_triples.is_complex() or mapped_triples.is_floating_point():
            raise TypeError(f"Invalid type: {mapped_triples.dtype}. Must be integer dtype.")
        # always store as torch.long, i.e., torch's default integer dtype
        self.mapped_triples = mapped_triples.to(dtype=torch.long)
        if metadata is None:
            metadata = dict()
        self.metadata = metadata

    @classmethod
    def create(
        cls,
        mapped_triples: MappedTriples,
        num_entities: Optional[int] = None,
        num_relations: Optional[int] = None,
        create_inverse_triples: bool = False,
        metadata: Optional[Mapping[str, Any]] = None,
    ) -> "CoreTriplesFactory":
        """
        Create a triples factory without any label information.

        :param mapped_triples: shape: (n, 3)
            The ID-based triples.
        :param num_entities:
            The number of entities. If not given, inferred from mapped_triples (0 if there are no triples).
        :param num_relations:
            The number of relations. If not given, inferred from mapped_triples (0 if there are no triples).
        :param create_inverse_triples:
            Whether to create inverse triples.
        :param metadata:
            Additional metadata to store in the factory.

        :return:
            A new triples factory.
        """
        # max() is undefined for an empty set of triples, so infer zero counts in that case
        is_empty = len(mapped_triples) == 0
        if num_entities is None:
            num_entities = 0 if is_empty else mapped_triples[:, [0, 2]].max().item() + 1
        if num_relations is None:
            num_relations = 0 if is_empty else mapped_triples[:, 1].max().item() + 1
        # NOTE(review): this builds a plain CoreTriplesFactory rather than ``cls``; presumably because subclass
        # constructors (e.g. TriplesFactory) take label mappings instead of counts.
        return CoreTriplesFactory(
            mapped_triples=mapped_triples,
            num_entities=num_entities,
            num_relations=num_relations,
            create_inverse_triples=create_inverse_triples,
            metadata=metadata,
        )

    def __eq__(self, __o: object) -> bool:  # noqa: D105
        if not isinstance(__o, CoreTriplesFactory):
            return False
        return (
            (self.num_entities == __o.num_entities)
            and (self.num_relations == __o.num_relations)
            and (self.num_triples == __o.num_triples)
            and (self.create_inverse_triples == __o.create_inverse_triples)
            # the triple counts are equal at this point, so the element-wise comparison has matching shapes
            and bool((self.mapped_triples == __o.mapped_triples).all().item())
        )

    @property
    def num_triples(self) -> int:  # noqa: D401
        """The number of triples."""
        return self.mapped_triples.shape[0]

    def iter_extra_repr(self) -> Iterable[str]:
        """Iterate over extra_repr components."""
        yield from super().iter_extra_repr()
        yield f"num_triples={self.num_triples}"
        for k, v in sorted(self.metadata.items()):
            if isinstance(v, (str, pathlib.Path)):
                v = f'"{v}"'
            yield f"{k}={v}"

    def with_labels(
        self,
        entity_to_id: Mapping[str, int],
        relation_to_id: Mapping[str, int],
    ) -> "TriplesFactory":
        """Add labeling to the TriplesFactory.

        :raises ValueError: if an ID occurring in the triples is missing from the corresponding new labeling
        """
        # check new label to ID mappings
        for name, columns, new_labeling in (
            ("entity", [0, 2], entity_to_id),
            ("relation", 1, relation_to_id),
        ):
            existing_ids = set(self.mapped_triples[:, columns].unique().tolist())
            if not existing_ids.issubset(new_labeling.values()):
                diff = existing_ids.difference(new_labeling.values())
                raise ValueError(f"Some existing IDs do not occur in the new {name} labeling: {diff}")
        return TriplesFactory(
            mapped_triples=self.mapped_triples,
            entity_to_id=entity_to_id,
            relation_to_id=relation_to_id,
            create_inverse_triples=self.create_inverse_triples,
            metadata=self.metadata,
        )

    def get_inverse_relation_id(self, relation: int) -> int:
        """Get the inverse relation identifier for the given relation.

        :raises ValueError: if this factory does not create inverse triples
        """
        if not self.create_inverse_triples:
            raise ValueError("Can not get inverse triple, they have not been created.")
        return relation_inverter.get_inverse_id(relation_id=relation)

    def _add_inverse_triples_if_necessary(self, mapped_triples: MappedTriples) -> MappedTriples:
        """Add inverse triples if they shall be created."""
        if not self.create_inverse_triples:
            return mapped_triples
        logger.info("Creating inverse triples.")
        # forward triples with mapped relation IDs, followed by inverted triples with head and tail swapped
        return torch.cat(
            [
                relation_inverter.map(batch=mapped_triples),
                relation_inverter.map(batch=mapped_triples, invert=True).flip(1),
            ]
        )

    def create_slcwa_instances(self, *, sampler: Optional[str] = None, **kwargs) -> Dataset:
        """Create sLCWA instances for this factory's triples.

        :raises AssertionError: if ``shuffle`` is passed with a false value
        """
        cls = BatchedSLCWAInstances if sampler is None else SubGraphSLCWAInstances
        if "shuffle" in kwargs:
            if kwargs.pop("shuffle"):
                warnings.warn("Training instances are always shuffled.", DeprecationWarning)
            else:
                raise AssertionError("If shuffle is provided, it must be True.")
        return cls(
            mapped_triples=self._add_inverse_triples_if_necessary(mapped_triples=self.mapped_triples),
            num_entities=self.num_entities,
            num_relations=self.num_relations,
            **kwargs,
        )

    def create_lcwa_instances(self, use_tqdm: Optional[bool] = None, target: Optional[int] = None) -> Dataset:
        """Create LCWA instances for this factory's triples.

        ``use_tqdm`` is accepted but not used by this implementation.
        """
        return LCWAInstances.from_triples(
            mapped_triples=self._add_inverse_triples_if_necessary(mapped_triples=self.mapped_triples),
            num_entities=self.num_entities,
            num_relations=self.num_relations,
            target=target,
        )

    def get_most_frequent_relations(self, n: Union[int, float]) -> Set[int]:
        """Get the IDs of the n most frequent relations.

        :param n:
            Either the (integer) number of top relations to keep or the (float) percentage of top relationships to keep.
            If this exceeds the number of distinct relations occurring in the triples, all of them are returned.
        :returns:
            A set of IDs for the n most frequent relations
        :raises TypeError:
            If the n is the wrong type
        """
        logger.info(f"applying cutoff of {n} to {self}")
        if isinstance(n, float):
            assert 0 < n < 1
            # NOTE(review): num_relations includes artificial inverse relations when create_inverse_triples is
            # set, while mapped_triples only contain real relation IDs; confirm this is the intended basis.
            n = int(self.num_relations * n)
        elif not isinstance(n, int):
            raise TypeError("n must be either an integer or a float")
        uniq, counts = self.mapped_triples[:, 1].unique(return_counts=True)
        # topk raises if k exceeds the number of candidates, so cap k at the number of distinct relations
        top_counts, top_ids = counts.topk(k=min(n, uniq.numel()), largest=True)
        return set(uniq[top_ids].tolist())

    def clone_and_exchange_triples(
        self,
        mapped_triples: MappedTriples,
        extra_metadata: Optional[Dict[str, Any]] = None,
        keep_metadata: bool = True,
        create_inverse_triples: Optional[bool] = None,
    ) -> "CoreTriplesFactory":
        """
        Create a new triples factory sharing everything except the triples.

        .. note ::
            We use shallow copies.

        :param mapped_triples:
            The new mapped triples.
        :param extra_metadata:
            Extra metadata to include in the new triples factory. If ``keep_metadata`` is true,
            the dictionaries will be unioned with precedence taken on keys from ``extra_metadata``.
        :param keep_metadata:
            Pass the current factory's metadata to the new triples factory
        :param create_inverse_triples:
            Change inverse triple creation flag. If None, use flag from this factory.

        :return:
            The new factory.
        """
        if create_inverse_triples is None:
            create_inverse_triples = self.create_inverse_triples
        return CoreTriplesFactory(
            mapped_triples=mapped_triples,
            num_entities=self.num_entities,
            # the real count; the constructor doubles it again if create_inverse_triples is set
            num_relations=self.real_num_relations,
            create_inverse_triples=create_inverse_triples,
            # later entries win: extra_metadata overrides the kept metadata, as documented
            metadata={
                **(self.metadata if keep_metadata else {}),  # type: ignore
                **(extra_metadata or {}),
            },
        )

    def split(
        self,
        ratios: Union[float, Sequence[float]] = 0.8,
        *,
        random_state: TorchRandomHint = None,
        randomize_cleanup: bool = False,
        method: Optional[str] = None,
    ) -> List["CoreTriplesFactory"]:
        """Split a triples factory into a train/test.

        :param ratios:
            There are three options for this argument:

            1. A float can be given between 0 and 1.0, non-inclusive. The first set of triples will
               get this ratio and the second will get the rest.
            2. A list of ratios can be given for which set in which order should get what ratios as in
               ``[0.8, 0.1]``. The final ratio can be omitted because that can be calculated.
            3. All ratios can be explicitly set in order such as in ``[0.8, 0.1, 0.1]``
               where the sum of all ratios is 1.0.
        :param random_state:
            The random state used to shuffle and split the triples.
        :param randomize_cleanup:
            If true, uses the non-deterministic method for moving triples to the training set. This has the
            advantage that it does not necessarily have to move all of them, but it might be significantly
            slower since it moves one triple at a time.
        :param method:
            The name of the method to use, from SPLIT_METHODS. Defaults to "coverage".

        :return:
            A partition of triples, which are split (approximately) according to the ratios, stored TriplesFactory's
            which share everything else with this root triples factory.

        .. code-block:: python

            ratio = 0.8  # makes a [0.8, 0.2] split
            training_factory, testing_factory = factory.split(ratio)

            ratios = [0.8, 0.1]  # makes a [0.8, 0.1, 0.1] split
            training_factory, testing_factory, validation_factory = factory.split(ratios)

            ratios = [0.8, 0.1, 0.1]  # also makes a [0.8, 0.1, 0.1] split
            training_factory, testing_factory, validation_factory = factory.split(ratios)
        """
        # Make new triples factories for each group
        return [
            self.clone_and_exchange_triples(
                mapped_triples=triples,
                # do not explicitly create inverse triples for testing; this is handled by the evaluation code
                create_inverse_triples=None if i == 0 else False,
            )
            for i, triples in enumerate(
                split(
                    mapped_triples=self.mapped_triples,
                    ratios=ratios,
                    random_state=random_state,
                    randomize_cleanup=randomize_cleanup,
                    method=method,
                )
            )
        ]

    def entities_to_ids(self, entities: Union[Collection[int], Collection[str]]) -> Collection[int]:
        """Normalize entities to IDs.

        :param entities: A collection of either integer identifiers for entities or
            string labels for entities (that will get auto-converted)
        :returns: Integer identifiers for entities
        :raises ValueError: If any element is not an ``int``, e.g., string labels, which this
            triples factory cannot convert since it has no entity label to identifier mapping
            (e.g., it's just a base :class:`CoreTriplesFactory` instance)
        """
        for e in entities:
            if not isinstance(e, int):
                raise ValueError(f"{self.__class__.__name__} cannot convert entity IDs from {type(e)} to int.")
        return cast(Collection[int], entities)

    def relations_to_ids(self, relations: Union[Collection[int], Collection[str]]) -> Collection[int]:
        """Normalize relations to IDs.

        :param relations: A collection of either integer identifiers for relations or
            string labels for relations (that will get auto-converted)
        :returns: Integer identifiers for relations
        :raises ValueError: If any element is not an ``int``, e.g., string labels, which this
            triples factory cannot convert since it has no relation label to identifier mapping
            (e.g., it's just a base :class:`CoreTriplesFactory` instance)
        """
        for e in relations:
            if not isinstance(e, int):
                raise ValueError(f"{self.__class__.__name__} cannot convert relation IDs from {type(e)} to int.")
        return cast(Collection[int], relations)

    def get_mask_for_relations(
        self,
        relations: Collection[int],
        invert: bool = False,
    ) -> torch.BoolTensor:
        """Get a boolean mask for triples with the given relations."""
        return _get_triple_mask(
            ids=relations,
            triples=self.mapped_triples,
            columns=1,
            invert=invert,
            max_id=self.num_relations,
        )

    def tensor_to_df(
        self,
        tensor: torch.LongTensor,
        **kwargs: Union[torch.Tensor, np.ndarray, Sequence],
    ) -> pd.DataFrame:
        """Take a tensor of triples and make a pandas dataframe with labels.

        :param tensor: shape: (n, 3)
            The triples, ID-based and in format (head_id, relation_id, tail_id).
        :param kwargs:
            Any additional number of columns. Each column needs to be of shape (n,). Reserved column names:
            {"head_id", "head_label", "relation_id", "relation_label", "tail_id", "tail_label"}.
        :return:
            A dataframe with n rows, and 6 + len(kwargs) columns.
        """
        # delegates to the module-level tensor_to_df helper
        return tensor_to_df(tensor=tensor, **kwargs)

    def new_with_restriction(
        self,
        entities: Union[None, Collection[int], Collection[str]] = None,
        relations: Union[None, Collection[int], Collection[str]] = None,
        invert_entity_selection: bool = False,
        invert_relation_selection: bool = False,
    ) -> "CoreTriplesFactory":
        """Make a new triples factory only keeping the given entities and relations, but keeping the ID mapping.

        :param entities:
            The entities of interest. If None, defaults to all entities.
        :param relations:
            The relations of interest. If None, defaults to all relations.
        :param invert_entity_selection:
            Whether to invert the entity selection, i.e. select those triples without the provided entities.
        :param invert_relation_selection:
            Whether to invert the relation selection, i.e. select those triples without the provided relations.

        :return:
            A new triples factory, which has only a subset of the triples containing the entities and relations of
            interest. The label-to-ID mapping is *not* modified. If no triple is removed, ``self`` is returned.
        """
        # prepare metadata
        extra_metadata = {}
        if entities is not None:
            extra_metadata["entity_restriction"] = entities
            entities = self.entities_to_ids(entities=entities)
            remaining_entities = (self.num_entities - len(entities)) if invert_entity_selection else len(entities)
            logger.info(f"keeping {format_relative_comparison(remaining_entities, self.num_entities)} entities.")
        if relations is not None:
            extra_metadata["relation_restriction"] = relations
            relations = self.relations_to_ids(relations=relations)
            remaining_relations = (self.num_relations - len(relations)) if invert_relation_selection else len(relations)
            logger.info(f"keeping {format_relative_comparison(remaining_relations, self.num_relations)} relations.")

        # Delegate to function
        mapped_triples = restrict_triples(
            mapped_triples=self.mapped_triples,
            entities=entities,
            relations=relations,
            invert_entity_selection=invert_entity_selection,
            invert_relation_selection=invert_relation_selection,
        )

        # restrict triples can only remove triples; thus, if the new size equals the old one, nothing has changed
        if mapped_triples.shape[0] == self.num_triples:
            return self

        logger.info(f"keeping {format_relative_comparison(mapped_triples.shape[0], self.num_triples)} triples.")

        return self.clone_and_exchange_triples(
            mapped_triples=mapped_triples,
            extra_metadata=extra_metadata,
        )

    @classmethod
    # docstr-coverage: inherited
    def from_path_binary(
        cls,
        path: Union[str, pathlib.Path, TextIO],
    ) -> "CoreTriplesFactory":  # noqa: D102
        """
        Load triples factory from a binary storage directory.

        :param path:
            The directory as written by :meth:`to_path_binary`, containing ``base_file_name`` and
            ``triples_file_name``.

        :return:
            The loaded triples factory.
        """
        path = normalize_path(path)
        logger.info(f"Loading from {path.as_uri()}")
        return cls(**cls._from_path_binary(path=path))

    @classmethod
    def _from_path_binary(
        cls,
        path: pathlib.Path,
    ) -> MutableMapping[str, Any]:
        """Read the constructor keyword arguments from a binary storage directory."""
        # load base
        # NOTE(review): torch.load unpickles the file; only load directories from trusted sources.
        data = dict(torch.load(path.joinpath(cls.base_file_name)))
        # load numeric triples
        data["mapped_triples"] = torch.as_tensor(
            pd.read_csv(path.joinpath(cls.triples_file_name), sep="\t", dtype=int).values,
            dtype=torch.long,
        )
        return data

    def to_path_binary(
        self,
        path: Union[str, pathlib.Path, TextIO],
    ) -> pathlib.Path:
        """
        Save triples factory to a directory, using a TSV file for the triples and PyTorch's format for the rest.

        :param path:
            The path to store the triples factory to.

        :returns:
            The normalized path; the triples file and the base file are written beneath it.
        """
        path = normalize_path(path, mkdir=True)

        # store numeric triples; .cpu() makes this work for tensors on other devices as well
        pd.DataFrame(
            data=self.mapped_triples.cpu().numpy(),
            columns=COLUMN_LABELS,
        ).to_csv(path.joinpath(self.triples_file_name), sep="\t", index=False)

        # store metadata
        torch.save(self._get_binary_state(), path.joinpath(self.base_file_name))
        logger.info(f"Stored {self} to {path.as_uri()}")

        return path

    def _get_binary_state(self):
        """Return the constructor keyword arguments (except the triples) to be stored in the base file."""
        return dict(
            num_entities=self.num_entities,
            # note: num_relations will be doubled again when instantiating with create_inverse_triples=True
            num_relations=self.real_num_relations,
            create_inverse_triples=self.create_inverse_triples,
            metadata=self.metadata,
        )
class TriplesFactory(CoreTriplesFactory):
"""Create instances given the path to triples."""
file_name_entity_to_id: ClassVar[str] = "entity_to_id"
file_name_relation_to_id: ClassVar[str] = "relation_to_id"
def __init__(
    self,
    mapped_triples: MappedTriples,
    entity_to_id: EntityMapping,
    relation_to_id: RelationMapping,
    create_inverse_triples: bool = False,
    metadata: Optional[Mapping[str, Any]] = None,
):
    """
    Create a labeled triples factory.

    :param mapped_triples: shape: (n, 3)
        A three-column matrix where each row are the head identifier, relation identifier, then tail identifier.
    :param entity_to_id:
        The mapping from entities' labels to their indices.
    :param relation_to_id:
        The mapping from relations' labels to their indices.
    :param create_inverse_triples:
        Whether to create inverse triples.
    :param metadata:
        Arbitrary metadata to go with the graph
    """
    entity_labeling = Labeling(label_to_id=entity_to_id)
    relation_labeling = Labeling(label_to_id=relation_to_id)
    # attach the labelings before the base initializer runs, as before
    self.entity_labeling = entity_labeling
    self.relation_labeling = relation_labeling
    # the entity/relation counts are derived from the largest ID of each labeling
    super().__init__(
        mapped_triples=mapped_triples,
        num_entities=entity_labeling.max_id,
        num_relations=relation_labeling.max_id,
        create_inverse_triples=create_inverse_triples,
        metadata=metadata,
    )
@classmethod
def from_labeled_triples(
    cls,
    triples: LabeledTriples,
    *,
    create_inverse_triples: bool = False,
    entity_to_id: Optional[EntityMapping] = None,
    relation_to_id: Optional[RelationMapping] = None,
    compact_id: bool = True,
    filter_out_candidate_inverse_relations: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> "TriplesFactory":
    """
    Create a new triples factory from label-based triples.

    :param triples: shape: (n, 3), dtype: str
        The label-based triples.
    :param create_inverse_triples:
        Whether to create inverse triples.
    :param entity_to_id:
        The mapping from entity labels to ID. If None, create a new one from the triples.
    :param relation_to_id:
        The mapping from relations labels to ID. If None, create a new one from the triples.
    :param compact_id:
        Whether to compact IDs such that the IDs are consecutive. Only applied to mappings created here.
    :param filter_out_candidate_inverse_relations:
        Whether to remove triples with relations with the inverse suffix.
    :param metadata:
        Arbitrary key/value pairs to store as metadata

    :return:
        A new triples factory.
    """
    # Check if the triples are inverted already
    # We re-create them pure index based to ensure that _all_ inverse triples are present and that they are
    # contained if and only if create_inverse_triples is True.
    if filter_out_candidate_inverse_relations:
        unique_relations, inverse = np.unique(triples[:, 1], return_inverse=True)
        suspected_to_be_inverse_relations = {r for r in unique_relations if r.endswith(INVERSE_SUFFIX)}
        if len(suspected_to_be_inverse_relations) > 0:
            logger.warning(
                f"Some triples already have the inverse relation suffix {INVERSE_SUFFIX}. "
                f"Re-creating inverse triples to ensure consistency. You may disable this behaviour by passing "
                f"filter_out_candidate_inverse_relations=False",
            )
            # positions (within unique_relations) of the suspected inverse relations
            relation_ids_to_remove = [
                i for i, r in enumerate(unique_relations.tolist()) if r in suspected_to_be_inverse_relations
            ]
            mask = np.isin(element=inverse, test_elements=relation_ids_to_remove, invert=True)
            # report "kept / total (percentage)" instead of a bare fraction mislabelled as a triple count
            logger.info(f"keeping {format_relative_comparison(int(mask.sum()), mask.shape[0])} triples.")
            triples = triples[mask]

    # Generate entity mapping if necessary
    if entity_to_id is None:
        entity_to_id = create_entity_mapping(triples=triples)
    if compact_id:
        entity_to_id = compact_mapping(mapping=entity_to_id)[0]

    # Generate relation mapping if necessary
    if relation_to_id is None:
        relation_to_id = create_relation_mapping(triples[:, 1])
    if compact_id:
        relation_to_id = compact_mapping(mapping=relation_to_id)[0]

    # Map triples of labels to triples of IDs.
    mapped_triples = _map_triples_elements_to_ids(
        triples=triples,
        entity_to_id=entity_to_id,
        relation_to_id=relation_to_id,
    )

    return cls(
        entity_to_id=entity_to_id,
        relation_to_id=relation_to_id,
        mapped_triples=mapped_triples,
        create_inverse_triples=create_inverse_triples,
        metadata=metadata,
    )
@classmethod
def from_path(
cls,
path: Union[str, pathlib.Path, TextIO],
*,
create_inverse_triples: bool = False,
entity_to_id: Optional[EntityMapping] = None,
relation_to_id: Optional[RelationMapping] = None,
compact_id: bool = True,
metadata: Optional[Dict[str, Any]] = None,
load_triples_kwargs: Optional[Mapping[str, Any]] = None,
**kwargs,
) -> "TriplesFactory":
"""
Create a new triples factory from triples stored in a file.
:param path:
The path where the label-based triples are stored.
:param create_inverse_triples:
Whether to create inverse triples.
:param entity_to_id:
The mapping from entity labels to ID. If None, create a new one from the triples.