lookup_embedder.py
from torch import Tensor
import torch.nn
import torch.nn.functional

from kge import Config, Dataset
from kge.job import Job
from kge.model import KgeEmbedder
from kge.misc import round_to_points

from typing import List


class LookupEmbedder(KgeEmbedder):
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        configuration_key: str,
        vocab_size: int,
        init_for_load_only=False,
    ):
        super().__init__(
            config, dataset, configuration_key, init_for_load_only=init_for_load_only
        )

        # read config
        self.normalize_p = self.get_option("normalize.p")
        self.space = self.check_option("space", ["euclidean", "complex"])
        # n3 is only accepted when space is complex
        if self.space == "complex":
            self.regularize = self.check_option("regularize", ["", "lp", "n3"])
        else:
            self.regularize = self.check_option("regularize", ["", "lp"])
        self.sparse = self.get_option("sparse")
        self.config.check("train.trace_level", ["batch", "epoch"])
        self.vocab_size = vocab_size

        round_embedder_dim_to = self.get_option("round_dim_to")
        if len(round_embedder_dim_to) > 0:
            self.dim = round_to_points(round_embedder_dim_to, self.dim)

        self._embeddings = torch.nn.Embedding(
            self.vocab_size, self.dim, sparse=self.sparse,
        )

        if not init_for_load_only:
            # initialize weights
            self.initialize(self._embeddings.weight.data)
            self._normalize_embeddings()

        # TODO handle negative dropout values for now (they can occur when
        # dropout is searched over with Ax)
        dropout = self.get_option("dropout")
        if dropout < 0:
            if config.get("job.auto_correct"):
                config.log(
                    "Setting {}.dropout to 0., "
                    "was set to {}.".format(configuration_key, dropout)
                )
                dropout = 0
        self.dropout = torch.nn.Dropout(dropout)

    def _normalize_embeddings(self):
        if self.normalize_p > 0:
            with torch.no_grad():
                self._embeddings.weight.data = torch.nn.functional.normalize(
                    self._embeddings.weight.data, p=self.normalize_p, dim=-1
                )

    def prepare_job(self, job: Job, **kwargs):
        from kge.job import TrainingJob

        super().prepare_job(job, **kwargs)
        if self.normalize_p > 0 and isinstance(job, TrainingJob):
            # just to be sure it's right initially
            job.pre_run_hooks.append(lambda job: self._normalize_embeddings())

            # normalize after each batch
            job.post_batch_hooks.append(lambda job: self._normalize_embeddings())

    @torch.no_grad()
    def init_pretrained(self, pretrained_embedder: KgeEmbedder) -> None:
        # copy embeddings from the pretrained embedder for all ids that occur
        # in both vocabularies
        (
            self_intersect_ind,
            pretrained_intersect_ind,
        ) = self._intersect_ids_with_pretrained_embedder(pretrained_embedder)
        self._embeddings.weight[
            torch.from_numpy(self_intersect_ind)
            .to(self._embeddings.weight.device)
            .long()
        ] = pretrained_embedder.embed(torch.from_numpy(pretrained_intersect_ind)).to(
            self._embeddings.weight.device
        )

    def embed(self, indexes: Tensor) -> Tensor:
        return self._postprocess(self._embeddings(indexes.long()))

    def embed_all(self) -> Tensor:
        return self._postprocess(self._embeddings_all())

    def _postprocess(self, embeddings: Tensor) -> Tensor:
        if self.dropout.p > 0:
            embeddings = self.dropout(embeddings)
        return embeddings

    def _embeddings_all(self) -> Tensor:
        return self._embeddings(
            torch.arange(
                self.vocab_size, dtype=torch.long, device=self._embeddings.weight.device
            )
        )

    def _get_regularize_weight(self) -> Tensor:
        return self.get_option("regularize_weight")

    def _abs_complex(self, parameters) -> Tensor:
        parameters_re, parameters_im = (
            t.contiguous() for t in parameters.chunk(2, dim=1)
        )
        # + 1e-14 to avoid NaN: https://github.com/lilanxiao/Rotated_IoU/issues/20
        parameters = torch.sqrt(parameters_re ** 2 + parameters_im ** 2 + 1e-14)
        return parameters

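    # Note on _abs_complex: chunk(2, dim=1) treats the first half of each embedding
    # as the real parts and the second half as the imaginary parts. For example, a
    # single row [1., 0., 0., 1.] (re = [1., 0.], im = [0., 1.]) yields approximately
    # [1., 1.], the moduli |1+0i| and |0+1i|, on which the N3 penalty below operates.
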
    def penalty(self, **kwargs) -> List[Tensor]:
        # TODO factor out to a utility method
        result = super().penalty(**kwargs)
        if self.regularize == "" or self.get_option("regularize_weight") == 0.0:
            pass
        elif self.regularize in ["lp", "n3"]:
            if self.regularize == "n3":
                p = 3
            else:
                p = (
                    self.get_option("regularize_args.p")
                    if self.has_option("regularize_args.p")
                    else 2
                )
            regularize_weight = self._get_regularize_weight()
            if not self.get_option("regularize_args.weighted"):
                # unweighted Lp regularization
                parameters = self._embeddings_all()
                if self.regularize == "n3" and self.space == "complex":
                    parameters = self._abs_complex(parameters)
                result += [
                    (
                        f"{self.configuration_key}.L{p}_penalty",
                        (regularize_weight / p * parameters.norm(p=p) ** p).sum(),
                    )
                ]
            else:
                # weighted Lp regularization
                unique_indexes, counts = torch.unique(
                    kwargs["indexes"], return_counts=True
                )
                parameters = self._embeddings(unique_indexes)
                if self.regularize == "n3" and self.space == "complex":
                    parameters = self._abs_complex(parameters)
                if (p % 2 == 1) and (self.regularize != "n3"):
                    parameters = torch.abs(parameters)
                result += [
                    (
                        f"{self.configuration_key}.L{p}_penalty",
                        (
                            regularize_weight
                            / p
                            * (parameters ** p * counts.float().view(-1, 1))
                        ).sum()
                        # In contrast to unweighted Lp regularization, rescaling by
                        # number of triples/indexes is necessary here so that penalty
                        # term is correct in expectation
                        / len(kwargs["indexes"]),
                    )
                ]
        else:  # unknown regularization
            raise ValueError(f"Invalid value regularize={self.regularize}")

        return result
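
# Illustrative sketch of the weighted vs. unweighted Lp penalty above, using only
# plain torch tensors; the sizes, seed, and variable names below are made up for
# the example and are not part of the kge API. The weighted branch rescales each
# embedding's contribution by how often its index occurs in the batch and divides
# by the batch size; for a batch in which every index occurs exactly once, this
# reduces to the unweighted penalty divided by the number of indexes.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, dim, p, weight = 5, 4, 2, 0.1
    embeddings = torch.randn(vocab_size, dim)

    # unweighted Lp penalty over all embeddings (cf. the unweighted branch above)
    unweighted = (weight / p * embeddings.norm(p=p) ** p).sum()

    # weighted Lp penalty for a batch that contains every index exactly once
    indexes = torch.arange(vocab_size)
    unique_indexes, counts = torch.unique(indexes, return_counts=True)
    parameters = embeddings[unique_indexes]
    weighted = (
        weight / p * (parameters ** p * counts.float().view(-1, 1))
    ).sum() / len(indexes)

    # the two agree up to the 1 / len(indexes) rescaling
    print(unweighted.item(), (weighted * len(indexes)).item())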