-
-
Notifications
You must be signed in to change notification settings - Fork 179
/
__init__.py
105 lines (92 loc) · 2.48 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
r"""An interaction model $f:\mathcal{E} \times \mathcal{R} \times \mathcal{E} \rightarrow \mathbb{R}$ computes a
real-valued score representing the plausibility of a triple $(h,r,t) \in \mathbb{K}$ given the embeddings for the
entities and relations. In general, a larger score indicates a higher plausibility. The interpretation of the
score value is model-dependent, and usually it cannot be directly interpreted as a probability.
""" # noqa: D205, D400
from typing import Mapping, Set, Type, Union
from .base import ( # noqa:F401
DoubleRelationEmbeddingModel, ERModel, LiteralModel, Model, SingleVectorEmbeddingModel,
TwoSideEmbeddingModel, TwoVectorEmbeddingModel,
)
from .multimodal import ComplExLiteral, DistMultLiteral
from .unimodal import (
ComplEx,
ConvE,
ConvKB,
DistMult,
ERMLP,
ERMLPE,
HolE,
KG2E,
NTN,
ProjE,
RESCAL,
RGCN,
RotatE,
SimplE,
StructuredEmbedding,
TransD,
TransE,
TransH,
TransR,
TuckER,
UnstructuredModel,
)
from ..utils import get_cls, normalize_string
# Public API of this subpackage: the concrete interaction models, followed by the
# name-based lookup helpers defined at the bottom of this module.
__all__ = [
    # Models
    'ComplEx', 'ComplExLiteral', 'ConvE', 'ConvKB', 'DistMult', 'DistMultLiteral',
    'ERMLP', 'ERMLPE', 'HolE', 'KG2E', 'NTN', 'ProjE', 'RESCAL', 'RGCN', 'RotatE',
    'SimplE', 'StructuredEmbedding', 'TransD', 'TransE', 'TransH', 'TransR',
    'TuckER', 'UnstructuredModel',
    # Lookup helpers
    'models',
    'get_model_cls',
]
# Intermediate base classes that concrete models build on. They are excluded from the
# model registry, but the subclass search still descends through them. This is a
# read-only lookup table, so it is stored as a frozenset to prevent accidental mutation.
_BASE_MODELS = frozenset({
    ERModel,
    SingleVectorEmbeddingModel,
    DoubleRelationEmbeddingModel,
    TwoSideEmbeddingModel,
    TwoVectorEmbeddingModel,
    LiteralModel,
})
def _concrete_subclasses(cls: Type[Model]):
    """Yield the concrete descendants of ``cls``, depth-first, each at most once.

    A descendant is yielded unless its ``_is_abstract`` attribute is truthy or it is
    one of the intermediate base classes listed in ``_BASE_MODELS``. Skipped classes
    are still descended into, so concrete models deriving from them are found.

    Only classes that already exist when this runs are visited, because
    ``type.__subclasses__`` reports the currently alive direct subclasses.

    :param cls: The root of the search; ``cls`` itself is never yielded.
    :yields: Concrete model classes.
    """
    # A class that inherits from several classes in the hierarchy is listed by each
    # parent's ``__subclasses__()``. Tracking visited classes stops it from being
    # yielded repeatedly and stops its whole subtree from being re-walked.
    visited: Set[type] = set()

    def _walk(parent):
        """Recursively yield not-yet-visited concrete subclasses of ``parent``."""
        for subcls in parent.__subclasses__():
            if subcls in visited:
                continue
            visited.add(subcls)
            if not subcls._is_abstract and subcls not in _BASE_MODELS:
                yield subcls
            yield from _walk(subcls)

    yield from _walk(cls)
# Every concrete model class reachable from ``Model`` when this module is imported.
_MODELS: Set[Type[Model]] = {*_concrete_subclasses(Model)}  # type: ignore

#: A mapping from each model's normalized class name to the model class
# NOTE(review): if two classes normalize to the same key, whichever is iterated last
# wins, and set iteration order is not guaranteed. Confirm no collisions exist.
models: Mapping[str, Type[Model]] = {
    normalize_string(model_cls.__name__): model_cls
    for model_cls in _MODELS
}
def get_model_cls(query: Union[str, Type[Model]]) -> Type[Model]:
    """Resolve a model class via the registry in :data:`pykeen.models.models`.

    Resolution is delegated to :func:`pykeen.utils.get_cls`, using :class:`Model` as the
    expected base class and the registry of normalized model names as the lookup table.

    :param query: The name of a model (case insensitive, punctuation insensitive), or,
        as the type annotation allows, a model class. How a class is handled is
        determined by ``get_cls``.
    :return: The model class
    """
    resolved = get_cls(
        query,
        lookup_dict=models,
        base=Model,  # type: ignore
    )
    return resolved