-
-
Notifications
You must be signed in to change notification settings - Fork 179
/
complex.py
111 lines (91 loc) · 4.26 KB
/
complex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# -*- coding: utf-8 -*-
"""Implementation of the ComplEx model."""
from typing import Optional
import torch.nn as nn
from ..base import SimpleVectorEntityRelationEmbeddingModel
from ...losses import Loss, SoftplusLoss
from ...nn.modules import ComplExInteractionFunction
from ...regularizers import LpRegularizer, Regularizer
from ...triples import TriplesFactory
from ...typing import DeviceHint
# Public API of this module: only the model class is exported.
__all__ = ['ComplEx']
class ComplEx(SimpleVectorEntityRelationEmbeddingModel):
    r"""An implementation of ComplEx [trouillon2016]_.

    ComplEx is an extension of :class:`pykeen.models.DistMult` that uses complex valued representations for the
    entities and relations. Entities and relations are represented as vectors
    $\textbf{e}_i, \textbf{r}_i \in \mathbb{C}^d$, and the plausibility score is the real part of the
    tri-linear product of the head, the relation and the *complex conjugate* of the tail:

    .. math::

        f(h,r,t) = Re\left(\left\langle\mathbf{e}_h,\mathbf{r}_r,\bar{\mathbf{e}}_t\right\rangle\right)

    where $\langle\mathbf{a},\mathbf{b},\mathbf{c}\rangle = \sum_k a_k b_k c_k$ is the sum over the element-wise
    (Hadamard) product. Which expands to:

    .. math::

        f(h,r,t) = \left\langle Re(\mathbf{e}_h),Re(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle
        + \left\langle Im(\mathbf{e}_h),Re(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle
        + \left\langle Re(\mathbf{e}_h),Im(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle
        - \left\langle Im(\mathbf{e}_h),Im(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle

    where $Re(\textbf{x})$ and $Im(\textbf{x})$ denote the real and imaginary parts of the complex valued vector
    $\textbf{x}$. The first two terms are symmetric in $h$ and $t$, whereas the last two terms change sign when
    $h$ and $t$ are swapped. Because of this contribution of the conjugated tail, the score is not symmetric
    whenever $Im(\mathbf{r}_r) \neq 0$, and ComplEx can model anti-symmetric relations in contrast to DistMult.

    .. seealso ::

        Official implementation: https://github.com/ttrouill/complex/
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = dict(
        embedding_dim=dict(type=int, low=50, high=300, q=50),
    )
    #: The default loss function class
    loss_default = SoftplusLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs = dict(reduction='mean')
    #: The regularizer used by [trouillon2016]_ for ComplEx.
    regularizer_default = LpRegularizer
    #: The LP settings used by [trouillon2016]_ for ComplEx.
    regularizer_default_kwargs = dict(
        weight=0.01,
        p=2.0,
        normalize=True,
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 200,
        automatic_memory_optimization: Optional[bool] = None,
        loss: Optional[Loss] = None,
        preferred_device: DeviceHint = None,
        random_seed: Optional[int] = None,
        regularizer: Optional[Regularizer] = None,
    ) -> None:
        """Initialize ComplEx.

        :param triples_factory:
            The triples factory connected to the model.
        :param embedding_dim:
            The number of complex dimensions of the entity and relation embeddings. The base class is given
            ``2 * embedding_dim`` as its (real-valued) embedding dimension.
        :param automatic_memory_optimization:
            Whether to automatically optimize the sub-batch size during training and batch size during evaluation
            with regards to the hardware at hand.
        :param loss:
            The loss to use. Defaults to SoftplusLoss.
        :param preferred_device:
            The default device where the model is located.
        :param random_seed:
            An optional random seed to set before the initialization of weights.
        :param regularizer:
            The regularizer to use.
        """
        interaction_function = ComplExInteractionFunction()
        super().__init__(
            triples_factory=triples_factory,
            interaction_function=interaction_function,
            # twice the requested size, since the embeddings are complex valued
            embedding_dim=2 * embedding_dim,  # complex embeddings
            automatic_memory_optimization=automatic_memory_optimization,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
            # initialize with entity and relation embeddings with standard normal distribution, cf.
            # https://github.com/ttrouill/complex/blob/dc4eb93408d9a5288c986695b58488ac80b1cc17/efe/models.py#L481-L487
            entity_initializer=nn.init.normal_,
            relation_initializer=nn.init.normal_,
        )