forked from flairNLP/flair
-
Notifications
You must be signed in to change notification settings - Fork 0
/
similarity_learning_model.py
386 lines (311 loc) · 13.7 KB
/
similarity_learning_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
from abc import abstractmethod
import flair
from flair.data import DataPoint, DataPair
from flair.embeddings import Embeddings
from flair.datasets import DataLoader
from flair.training_utils import Result
from flair.training_utils import store_embeddings
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import itertools
from typing import Union, List
from pathlib import Path
# == similarity measures ==
class SimilarityMeasure:
    """
    Base class for similarity measures.

    Subclasses implement :meth:`forward`, which receives a pair (indexable as
    ``x[0]`` and ``x[1]``) and returns the similarity between its two elements.
    """

    @abstractmethod
    def forward(self, x):
        """
        Compute the similarity for the pair ``x``.

        :param x: indexable pair; ``x[0]`` and ``x[1]`` are the two sides to compare
        :raises NotImplementedError: always, in this base class
        """
        # This class does not use ABCMeta, so @abstractmethod alone is not enforced;
        # fail loudly instead of silently returning None when a subclass forgets
        # to override forward().
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )
# helper class for ModelSimilarity
class SliceReshaper(flair.nn.Model):
    """
    Picks a slice along the second dimension of its input and optionally reshapes it.

    :param begin: index (or slice start) along dimension 1
    :param end: exclusive slice end; if None, the single index ``begin`` is selected
    :param shape: if given, the selection is viewed as ``(-1, *shape)``
    """

    def __init__(self, begin, end=None, shape=None):
        super().__init__()
        self.begin, self.end, self.shape = begin, end, shape

    def forward(self, x):
        # single index when no end is configured, otherwise the range begin:end
        if self.end is None:
            selected = x[:, self.begin]
        else:
            selected = x[:, self.begin : self.end]

        if self.shape is None:
            return selected
        return selected.view(-1, *self.shape)
# -- works with binary cross entropy loss --
class ModelSimilarity(SimilarityMeasure):
    """
    Similarity computed by running a model.

    The first element of the pair supplies the model parameters; the second element
    is fed through the parametrized model as input, and the output of that forward
    pass is the similarity.
    """

    def __init__(self, model):
        # model: list of (function, parameters) tuples, where parameters is a dict
        # {param_name: param_extract_model}
        self.model = model

    def forward(self, x):
        model_parameters, model_inputs = x[0], x[1]

        outputs = model_inputs
        for layer_fn, parameter_map in self.model:
            # SliceReshaper entries are extracted from the parameter tensor;
            # anything else is passed to the layer unchanged
            layer_kwargs = {
                name: (
                    extractor(model_parameters)
                    if isinstance(extractor, SliceReshaper)
                    else extractor
                )
                for name, extractor in parameter_map.items()
            }
            outputs = layer_fn(outputs, **layer_kwargs)
        return outputs
# -- works with ranking/triplet loss --
class CosineSimilarity(SimilarityMeasure):
    """
    Similarity defined by the cosine similarity between the two sides of a pair.
    """

    def forward(self, x):
        """
        :param x: pair of embedding tensors; each is normalized along its last dimension
        :return: matrix product of the normalized ``x[0]`` with the transposed,
            normalized ``x[1]``
        """
        input_modality_0 = x[0]
        input_modality_1 = x[1]

        # normalize the embeddings
        input_modality_0_norms = torch.norm(input_modality_0, dim=-1, keepdim=True)
        input_modality_1_norms = torch.norm(input_modality_1, dim=-1, keepdim=True)

        # Guard against zero-norm embeddings: 0 / 0 would put NaNs into the
        # similarity matrix. With the clamp, a zero vector gets similarity 0.
        input_modality_0_norms = input_modality_0_norms.clamp_min(
            torch.finfo(input_modality_0_norms.dtype).tiny
        )
        input_modality_1_norms = input_modality_1_norms.clamp_min(
            torch.finfo(input_modality_1_norms.dtype).tiny
        )

        return torch.matmul(
            input_modality_0 / input_modality_0_norms,
            (input_modality_1 / input_modality_1_norms).t(),
        )
# == similarity losses ==
class SimilarityLoss(nn.Module):
    """
    Base class for losses computed from a similarity matrix and a target matrix.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, inputs, targets):
        """
        Compute the loss.

        :param inputs: pairwise similarities
        :param targets: pair labels matching ``inputs``
        :raises NotImplementedError: always, in this base class
        """
        # nn.Module is not an ABC, so @abstractmethod is not enforced here. Raise
        # explicitly so a subclass without forward() fails at the call site
        # instead of producing a None loss.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )
class PairwiseBCELoss(SimilarityLoss):
    """
    Binary cross entropy between pair similarities and pair labels.

    :param balanced: if True, each element's loss is re-weighted so that positive
        and negative pairs contribute comparably (see the weighting in ``forward``)
    """

    def __init__(self, balanced=False):
        super().__init__()
        self.balanced = balanced

    def forward(self, inputs, targets):
        """
        :param inputs: similarity logits
        :param targets: labels with the same shape as ``inputs`` (1 for corresponding pairs)
        :return: mean of the (optionally weighted) element-wise BCE-with-logits loss
        """
        n = inputs.shape[0]

        # we want that logits for corresponding pairs are high, and for non-corresponding low
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

        if self.balanced:
            # ones_like already lives on the device of `targets`; forcing it onto
            # flair.device could mismatch devices when targets live elsewhere.
            neg_targets = torch.ones_like(targets) - targets
            # TODO: this assumes eye matrix
            # max(n - 1, 1) avoids a division by zero for a batch of a single pair
            weight_matrix = n * (
                targets / 2.0 + neg_targets / (2.0 * max(n - 1, 1))
            )
            bce_loss = bce_loss * weight_matrix

        return bce_loss.mean()
class RankingLoss(SimilarityLoss):
    """
    Triplet ranking loss between pair similarities and pair labels.

    :param margin: margin added to each negative-vs-diagonal similarity difference
    :param direction_weights: two weights, for the modality 0 => 1 direction and the
        modality 1 => 0 direction respectively
    """

    def __init__(self, margin=0.1, direction_weights=[0.5, 0.5]):
        super().__init__()
        self.margin = margin
        # copy, so the shared default list is never aliased between instances
        self.direction_weights = list(direction_weights)

    def forward(self, inputs, targets):
        """
        :param inputs: square similarity matrix; the diagonal is used as the
            similarity of the matching pair
        :param targets: labels with the same shape as ``inputs`` (1 for corresponding pairs)
        :return: weighted sum of the mean ranking losses of both directions
        """
        n = inputs.shape[0]
        neg_targets = torch.ones_like(targets) - targets
        positives = torch.diag(inputs)

        # loss matrices for two directions of alignment, from modality 0 => modality 1 and vice versa
        ranking_loss_matrix_01 = neg_targets * F.relu(
            self.margin + inputs - positives.view(n, 1)
        )
        ranking_loss_matrix_10 = neg_targets * F.relu(
            self.margin + inputs - positives.view(1, n)
        )

        # Number of negatives per row / per column. keepdim=True keeps the counts
        # aligned with the axis they were summed over: without it the per-row
        # counts would broadcast across columns and normalize the wrong entries
        # whenever rows have different numbers of negatives.
        # clamp_min(1) avoids 0 / 0 for rows/columns without any negative (their
        # loss entries are already zero because of the neg_targets mask).
        neg_targets_01_sum = torch.sum(neg_targets, dim=1, keepdim=True).clamp_min(1)
        neg_targets_10_sum = torch.sum(neg_targets, dim=0, keepdim=True).clamp_min(1)

        loss_01 = torch.mean(
            torch.sum(ranking_loss_matrix_01 / neg_targets_01_sum, dim=1)
        )
        loss_10 = torch.mean(
            torch.sum(ranking_loss_matrix_10 / neg_targets_10_sum, dim=0)
        )

        return (
            self.direction_weights[0] * loss_01
            + self.direction_weights[1] * loss_10
        )
# == similarity learner ==
class SimilarityLearner(flair.nn.Model):
    """
    Model that embeds the two sides of data pairs (source and target), scores them
    with a similarity measure, and is trained with a similarity loss.

    Evaluation ranks, for every source, all unique targets of the evaluation data
    and reports the median rank and recall at the configured cut-offs.
    """

    def __init__(
        self,
        source_embeddings: Embeddings,
        target_embeddings: Embeddings,
        similarity_measure: SimilarityMeasure,
        similarity_loss: SimilarityLoss,
        eval_device=flair.device,
        source_mapping: torch.nn.Module = None,
        target_mapping: torch.nn.Module = None,
        recall_at_points: List[int] = [1, 5, 10, 20],
        recall_at_points_weights: List[float] = [0.4, 0.3, 0.2, 0.1],
    ):
        """
        :param source_embeddings: embeddings applied to the first element of each pair
        :param target_embeddings: embeddings applied to the second element of each pair
        :param similarity_measure: measure applied to (source, target) embedding tensors
        :param similarity_loss: loss applied to the similarity matrix and its targets
        :param eval_device: device on which embeddings are kept during evaluation
        :param source_mapping: optional module applied to the stacked source embeddings
        :param target_mapping: optional module applied to the stacked target embeddings
        :param recall_at_points: cut-offs k for which recall@k is reported
        :param recall_at_points_weights: weights used to combine the recall@k values
            into the single validated measure
        """
        super().__init__()

        self.source_embeddings: Embeddings = source_embeddings
        self.target_embeddings: Embeddings = target_embeddings
        self.source_mapping: torch.nn.Module = source_mapping
        self.target_mapping: torch.nn.Module = target_mapping
        self.similarity_measure: SimilarityMeasure = similarity_measure
        self.similarity_loss: SimilarityLoss = similarity_loss
        self.eval_device = eval_device
        # copy, so the shared default lists are never aliased between instances
        self.recall_at_points: List[int] = list(recall_at_points)
        self.recall_at_points_weights: List[float] = list(recall_at_points_weights)

        self.to(flair.device)

    def _embed_source(self, data_points):
        """Embed the source side of ``data_points`` and return the (mapped) stacked tensor."""
        # pairs contribute their first element; anything else is embedded as given
        if isinstance(data_points[0], DataPair):
            data_points = [point.first for point in data_points]

        self.source_embeddings.embed(data_points)

        source_embedding_tensor = torch.stack(
            [point.embedding for point in data_points]
        ).to(flair.device)

        if self.source_mapping is not None:
            source_embedding_tensor = self.source_mapping(source_embedding_tensor)

        return source_embedding_tensor

    def _embed_target(self, data_points):
        """Embed the target side of ``data_points`` and return the (mapped) stacked tensor."""
        # pairs contribute their second element; anything else is embedded as given
        if isinstance(data_points[0], DataPair):
            data_points = [point.second for point in data_points]

        self.target_embeddings.embed(data_points)

        target_embedding_tensor = torch.stack(
            [point.embedding for point in data_points]
        ).to(flair.device)

        if self.target_mapping is not None:
            target_embedding_tensor = self.target_mapping(target_embedding_tensor)

        return target_embedding_tensor

    def get_similarity(self, modality_0_embeddings, modality_1_embeddings):
        """
        :param modality_0_embeddings: embeddings of first modality, a tensor of shape [n0, d0]
        :param modality_1_embeddings: embeddings of second modality, a tensor of shape [n1, d1]
        :return: a similarity matrix of shape [n0, n1]
        """
        return self.similarity_measure.forward(
            [modality_0_embeddings, modality_1_embeddings]
        )

    def forward_loss(
        self, data_points: Union[List[DataPoint], DataPoint]
    ) -> torch.Tensor:
        """
        Compute the similarity loss for a batch of data pairs.

        The target matrix marks entry (i, j) with 1 whenever some pair in the batch
        has the same first element (by ``str``) as pair i and the same second
        element (by ``str``) as pair j; all other entries are 0.

        :param data_points: batch of pairs exposing ``first`` and ``second``
        :return: loss computed by ``similarity_loss`` on the similarity matrix and targets
        """
        mapped_source_embeddings = self._embed_source(data_points)
        mapped_target_embeddings = self._embed_target(data_points)

        similarity_matrix = self.similarity_measure.forward(
            (mapped_source_embeddings, mapped_target_embeddings)
        )

        # string form of each side -> batch positions at which it occurs
        index_map = {"first": {}, "second": {}}
        for data_point_id, data_point in enumerate(data_points):
            index_map["first"].setdefault(str(data_point.first), []).append(
                data_point_id
            )
            index_map["second"].setdefault(str(data_point.second), []).append(
                data_point_id
            )

        targets = torch.zeros_like(similarity_matrix).to(flair.device)

        for data_point in data_points:
            first_indices = index_map["first"][str(data_point.first)]
            second_indices = index_map["second"][str(data_point.second)]
            for first_index, second_index in itertools.product(
                first_indices, second_indices
            ):
                targets[first_index, second_index] = 1.0

        return self.similarity_loss(similarity_matrix, targets)

    def evaluate(
        self,
        data_loader: DataLoader,
        out_path: Path = None,
        embeddings_storage_mode="none",
    ) -> (Result, float):
        """
        Rank all unique targets of ``data_loader`` for every source and report
        median rank and recall at ``recall_at_points``.

        :param data_loader: yields batches of pairs; it is iterated twice
        :param out_path: not used by this method
        :param embeddings_storage_mode: forwarded to ``store_embeddings`` after each batch
        :return: tuple of a ``Result`` (whose first argument is the weighted sum of the
            recall values) and the constant ``0``
        """
        # assumes that for each data pair there's at least one embedding per modality

        with torch.no_grad():
            # pre-compute embeddings for all targets in evaluation dataset
            target_index = {}
            all_target_embeddings = []
            for data_points in data_loader:
                # only targets not seen before (by string form) are embedded
                target_inputs = []
                for data_point in data_points:
                    target_key = str(data_point.second)
                    if target_key not in target_index:
                        target_index[target_key] = len(target_index)
                        target_inputs.append(data_point)
                if target_inputs:
                    all_target_embeddings.append(
                        self._embed_target(target_inputs).to(self.eval_device)
                    )
                store_embeddings(data_points, embeddings_storage_mode)

            # one row per unique target, in order of first appearance
            all_target_embeddings = torch.cat(all_target_embeddings, dim=0)
            assert len(target_index) == all_target_embeddings.shape[0]

            ranks = []
            for data_points in data_loader:
                batch_embeddings = self._embed_source(data_points)
                batch_source_embeddings = batch_embeddings.to(self.eval_device)

                # compute the similarity
                batch_similarity_matrix = self.similarity_measure.forward(
                    [batch_source_embeddings, all_target_embeddings]
                )

                # sort the similarity matrix across modality 1
                batch_modality_1_argsort = torch.argsort(
                    batch_similarity_matrix, descending=True, dim=1
                )

                # get the ranks, so +1 to start counting ranks from 1
                batch_modality_1_ranks = (
                    torch.argsort(batch_modality_1_argsort, dim=1) + 1
                )

                batch_target_indices = [
                    target_index[str(data_point.second)] for data_point in data_points
                ]

                # rank of the ground-truth target for every source in the batch
                batch_gt_ranks = batch_modality_1_ranks[
                    torch.arange(batch_similarity_matrix.shape[0]),
                    torch.tensor(batch_target_indices),
                ]
                ranks.extend(batch_gt_ranks.tolist())

                store_embeddings(data_points, embeddings_storage_mode)

        ranks = np.array(ranks)
        median_rank = np.median(ranks)
        recall_at = {k: np.mean(ranks <= k) for k in self.recall_at_points}

        results_header = ["Median rank"] + [
            "Recall@top" + str(r) for r in self.recall_at_points
        ]
        results_header_str = "\t".join(results_header)
        epoch_results = [str(median_rank)] + [
            str(recall_at[k]) for k in self.recall_at_points
        ]
        epoch_results_str = "\t".join(epoch_results)
        detailed_results = ", ".join(
            [f"{h}={v}" for h, v in zip(results_header, epoch_results)]
        )

        # single score: recall values combined with their configured weights
        validated_measure = sum(
            [
                recall_at[r] * w
                for r, w in zip(self.recall_at_points, self.recall_at_points_weights)
            ]
        )

        return (
            Result(
                validated_measure,
                results_header_str,
                epoch_results_str,
                detailed_results,
            ),
            0,
        )

    def _get_state_dict(self):
        """Return a dict with the module state and every constructor argument."""
        model_state = {
            "state_dict": self.state_dict(),
            "input_modality_0_embedding": self.source_embeddings,
            "input_modality_1_embedding": self.target_embeddings,
            "similarity_measure": self.similarity_measure,
            "similarity_loss": self.similarity_loss,
            "source_mapping": self.source_mapping,
            "target_mapping": self.target_mapping,
            "eval_device": self.eval_device,
            "recall_at_points": self.recall_at_points,
            "recall_at_points_weights": self.recall_at_points_weights,
        }
        return model_state

    # The function takes no `self`; without @staticmethod, calling it on an
    # instance would pass the instance as `state`. Calls through the class
    # keep working unchanged.
    @staticmethod
    def _init_model_with_state_dict(state):
        """
        Rebuild a ``SimilarityLearner`` from a dict as produced by ``_get_state_dict``.

        :param state: state dict; an older layout with an ``"input_embeddings"`` pair
            is converted in place to the current keys
        :return: the restored model with its weights loaded
        """
        # The conversion from old model's constructor interface
        if "input_embeddings" in state:
            state["input_modality_0_embedding"] = state["input_embeddings"][0]
            state["input_modality_1_embedding"] = state["input_embeddings"][1]

        model = SimilarityLearner(
            source_embeddings=state["input_modality_0_embedding"],
            target_embeddings=state["input_modality_1_embedding"],
            source_mapping=state["source_mapping"],
            target_mapping=state["target_mapping"],
            similarity_measure=state["similarity_measure"],
            similarity_loss=state["similarity_loss"],
            eval_device=state["eval_device"],
            recall_at_points=state["recall_at_points"],
            recall_at_points_weights=state["recall_at_points_weights"],
        )

        model.load_state_dict(state["state_dict"])
        return model