-
Notifications
You must be signed in to change notification settings - Fork 240
/
collaborative_experts.py
1250 lines (1107 loc) · 54.7 KB
/
collaborative_experts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Built on top of the original implementation at https://github.com/albanie/collaborative-experts
#
# Modifications by Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Dict
from towhee.models.collaborative_experts.util import expert_tensor_storage
from towhee.models.collaborative_experts.net_vlad import NetVLAD
from torch.autograd import Variable
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
import itertools
class Mish(nn.Module):
    """
    Element-wise Mish activation.

    Computes ``mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))``.
    SRC: https://github.com/digantamisra98/Mish/blob/master/Mish/Torch/mish.py
    """

    def forward(self, input_):
        """
        Apply Mish to every element of ``input_``.

        Args:
            input_ (torch.Tensor): tensor of any shape.

        Returns:
            torch.Tensor: tensor with the same shape as ``input_``.
        """
        gate = torch.tanh(F.softplus(input_))
        return input_ * gate
def kronecker_prod(t1, t2):
    """Batched Kronecker (outer) product along the last dimension.

    For every ``(b, n)`` position, the last-dim vectors ``a = t1[b, n]`` (length
    D1) and ``c = t2[b, n]`` (length D2) are combined into the flattened outer
    product, with ``a[i] * c[j]`` stored at index ``i * D2 + j``.

    Args:
        t1 (torch.Tensor): tensor of shape ``(B, N, D1)``; may be non-contiguous
            (e.g. a slice of a wider tensor).
        t2 (torch.Tensor): tensor of shape ``(B, N, D2)``.

    Returns:
        torch.Tensor: tensor of shape ``(B, N, D1 * D2)``.
    """
    # reshape (rather than view) also accepts non-contiguous inputs; the
    # original only made t2 contiguous, so a strided t1 made view() raise.
    lhs = t1.reshape(-1, t1.size(-1), 1)
    rhs = t2.reshape(-1, 1, t2.size(-1))
    kron = torch.bmm(lhs, rhs)
    return kron.view(t1.shape[0], t1.shape[1], -1)
def drop_nans(x, ind, validate_missing):
    """Zero out the rows of ``x`` that ``ind`` marks as missing.

    Missing rows are expected to contain NaNs; they are overwritten with zeros.
    NOTE: ``x`` is modified in place and the same tensor object is returned.

    Args:
        x (torch.Tensor): features; its first dimension is indexed by ``ind``.
        ind (torch.Tensor): binary values denoting whether or not a given feature
            is present (0 means missing).
        validate_missing (bool): if True, check the first element of the first
            missing row before zeroing. NOTE(review): this is a truthiness check
            (it only rejects an exact 0), not a strict NaN check, despite the
            assertion message.

    Returns:
        (torch.Tensor): the features, with the missing rows set to zero.
    """
    missing_rows = torch.nonzero(ind == 0).flatten()
    if not missing_rows.numel():
        return x
    if validate_missing:
        first_value = x[missing_rows[0]].view(-1)[0]
        assert first_value, "expected nans at missing locations"
    x[missing_rows] = 0
    return x
class CENet(nn.Module):
    """
    Collaborative Experts Module.

    Pre-processes the per-expert video features and delegates fusion and
    scoring to :class:`CEModule`. Pre-processing covers:

    * optional randomisation of selected experts;
    * NaN removal;
    * NetVLAD pooling for experts stored in variable-length form;
    * optional fusion of box and semantic features for ``detection-sem``.

    For retrieval tasks it also pools the text features.
    """
    def __init__(
            self,
            task,
            use_ce,
            text_dim,
            l2renorm,
            expert_dims,
            vlad_clusters,
            ghost_clusters,
            disable_nan_checks,
            keep_missing_modalities,
            test_caption_mode,
            randomise_feats,
            feat_aggregation,
            ce_shared_dim,
            trn_config,
            trn_cat,
            include_self,
            use_mish,
            use_bn_reason,
            num_h_layers,
            num_g_layers,
            kron_dets=False,
            freeze_weights=False,
            geometric_mlp=False,
            rand_proj=False,
            mimic_ce_dims=False,
            coord_dets=False,
            concat_experts=False,
            spatial_feats=False,
            concat_mix_experts=False,
            verbose=False,
            num_classes=None):
        super().__init__()
        self.l2renorm = l2renorm
        self.task = task
        self.geometric_mlp = geometric_mlp
        self.feat_aggregation = feat_aggregation
        self.expert_dims = expert_dims
        self.num_h_layers = num_h_layers
        self.num_g_layers = num_g_layers
        self.use_mish = use_mish
        self.use_bn_reason = use_bn_reason
        # Misspelled attribute name kept as an alias for backward compatibility.
        self.use_bn_resaon = use_bn_reason
        self.include_self = include_self
        self.kron_dets = kron_dets
        self.rand_proj = rand_proj
        self.coord_dets = coord_dets
        self.disable_nan_checks = disable_nan_checks
        self.trn_config = trn_config
        self.trn_cat = trn_cat
        # comma-separated expert names whose features get replaced by noise
        if randomise_feats:
            self.random_feats = set(randomise_feats.split(","))
        else:
            self.random_feats = set()

        # sanity checks on the features that may be vladded; an expert without an
        # entry in feat_aggregation is treated as not vladded (instead of raising
        # a KeyError)
        pre_vlad_feat_sizes = {"ocr": 300, "audio": 128, "speech": 300}
        pre_vlad_feat_sizes = {
            key: val for key, val in pre_vlad_feat_sizes.items()
            if key in feat_aggregation and feat_aggregation[key]["temporal"] == "vlad"
        }

        # we basically disable safety checks for detection-sem
        if spatial_feats:
            spatial_feat_dim = 16
        else:
            spatial_feat_dim = 5
        if self.geometric_mlp:
            self.geometric_mlp_model = SpatialMLP(spatial_feat_dim)
        if kron_dets:
            sem_det_dim = 300 * spatial_feat_dim
        elif coord_dets:
            sem_det_dim = spatial_feat_dim
        elif rand_proj:
            sem_det_dim = 300 + 300
            self.proj = nn.Linear(spatial_feat_dim, 300)
        else:
            sem_det_dim = 300 + spatial_feat_dim
        self.spatial_feat_dim = spatial_feat_dim
        pre_vlad_feat_sizes["detection-sem"] = sem_det_dim
        if "detection-sem" in expert_dims:
            new_in_dim = sem_det_dim * vlad_clusters["detection-sem"]
            expert_dims["detection-sem"] = (new_in_dim, expert_dims["detection-sem"][1])

        vlad_feat_sizes = dict(vlad_clusters.items())

        self.pooling = nn.ModuleDict()
        for mod, expected in pre_vlad_feat_sizes.items():
            if mod in expert_dims.keys():
                feature_size = expert_dims[mod][0] // vlad_clusters[mod]
                msg = f"expected {expected} for {mod} features atm"
                assert feature_size == expected, msg
                self.pooling[mod] = NetVLAD(
                    feature_size=feature_size,
                    cluster_size=vlad_clusters[mod],
                )

        if "retrieval" in self.task:
            if vlad_clusters["text"] == 0:
                self.text_pooling = nn.Sequential()
            else:
                self.text_pooling = NetVLAD(
                    feature_size=text_dim,
                    cluster_size=vlad_clusters["text"],
                    ghost_clusters=ghost_clusters["text"],
                )
                text_dim = self.text_pooling.out_dim
        else:
            self.num_classes = num_classes
            text_dim = None

        self.tensor_storage = expert_tensor_storage(
            experts=self.expert_dims.keys(),
            feat_aggregation=self.feat_aggregation,
        )

        self.ce = CEModule(
            use_ce=use_ce,
            task=self.task,
            verbose=verbose,
            l2renorm=l2renorm,
            trn_cat=self.trn_cat,
            trn_config=self.trn_config,
            random_feats=self.random_feats,
            freeze_weights=freeze_weights,
            text_dim=text_dim,
            test_caption_mode=test_caption_mode,
            concat_experts=concat_experts,
            concat_mix_experts=concat_mix_experts,
            expert_dims=expert_dims,
            vlad_feat_sizes=vlad_feat_sizes,
            disable_nan_checks=disable_nan_checks,
            keep_missing_modalities=keep_missing_modalities,
            mimic_ce_dims=mimic_ce_dims,
            include_self=include_self,
            use_mish=use_mish,
            use_bn_reason=use_bn_reason,
            num_h_layers=num_h_layers,
            num_g_layers=num_g_layers,
            num_classes=num_classes,
            same_dim=ce_shared_dim,
        )

    def randomise_feats(self, experts, key):
        """Replace ``experts[key]`` with Gaussian noise if ``key`` is in
        ``self.random_feats``.

        NaN positions of the original features are re-inserted into the noise
        unless NaN checks are disabled. ``experts`` is updated in place and
        returned.
        """
        if key in self.random_feats:
            # keep expected nans
            nan_mask = torch.isnan(experts[key])
            experts[key] = torch.randn_like(experts[key])
            if not self.disable_nan_checks:
                experts[key][nan_mask] = float("nan")
        return experts

    def forward(self, experts, ind, text=None, raw_captions=None, text_token_mask=None):
        """Pool experts (and text, for retrieval) and run the CE module.

        Args:
            experts (dict[str, torch.Tensor]): per-expert features; modified in
                place.
            ind (dict[str, torch.Tensor]): per-expert availability flags
                (0 = missing).
            text (torch.Tensor): for retrieval, unpacked as
                ``(B, captions_per_video, max_words, text_feat_dim)``.
            raw_captions: forwarded to the CE module.
            text_token_mask: passed as ``mask`` to the text NetVLAD, if used.

        Returns:
            dict: the output of :class:`CEModule`.
        """
        aggregated_experts = OrderedDict()

        if "detection-sem" in self.expert_dims:
            det_sem = experts["detection-sem"]
            box_feats = det_sem[:, :, :self.spatial_feat_dim]
            sem_feats = det_sem[:, :, self.spatial_feat_dim:]
            if self.geometric_mlp:
                x = box_feats.view(-1, box_feats.shape[-1])
                x = self.geometric_mlp_model(x)
                box_feats = x.view(box_feats.shape)

            if self.kron_dets:
                feats = kronecker_prod(box_feats, sem_feats)
            elif self.coord_dets:
                feats = box_feats.contiguous()
            elif self.rand_proj:
                feats = box_feats.contiguous()
                projected = self.proj(feats)
                feats = torch.cat((projected, sem_feats.contiguous()), dim=2)
            else:
                feats = torch.cat((box_feats, sem_feats.contiguous()), dim=2)
            experts["detection-sem"] = feats

        # Handle all nan-checks
        for mod in self.expert_dims:
            experts = self.randomise_feats(experts, mod)
            experts[mod] = drop_nans(x=experts[mod], ind=ind[mod], validate_missing=True)
            if mod in self.tensor_storage["fixed"]:
                aggregated_experts[mod] = experts[mod]
            elif mod in self.tensor_storage["variable"]:
                aggregated_experts[mod] = self.pooling[mod](experts[mod])

        if "retrieval" in self.task:
            bb, captions_per_video, max_words, text_feat_dim = text.size()
            text = text.view(bb * captions_per_video, max_words, text_feat_dim)
            if isinstance(self.text_pooling, NetVLAD):
                kwargs = {"mask": text_token_mask}
            else:
                kwargs = {}
            text = self.text_pooling(text, **kwargs)
            text = text.view(bb, captions_per_video, -1)
        else:
            text = None
        return self.ce(text, aggregated_experts, ind, raw_captions)
class TemporalAttention(torch.nn.Module):
    """
    TemporalAttention Module

    Summarises a sequence of features along dim 1 with average pooling, max
    pooling and ``num_attention`` learned attention heads, concatenating the
    results along the last dim. The input is indexed as
    ``(batch, time, img_feature_dim)``. The output has shape
    ``(batch, img_feature_dim * (num_attention + 2))``.
    """
    def __init__(self, img_feature_dim, num_attention):
        super().__init__()
        # Registered as a Parameter so that it is trained by the optimiser, saved
        # in the state dict and moved by ``.to(device)``. Previously it was a
        # plain attribute created with ``.cuda()``, which crashed without a GPU
        # and was never updated by training.
        self.weight = torch.nn.Parameter(
            torch.randn(img_feature_dim, num_attention))  # d*seg
        self.img_feature_dim = img_feature_dim
        self.num_attention = num_attention

    def forward(self, input_):
        """Return ``cat([mean_t, max_t, head_1, ..., head_A], dim=1)``.

        Each attention head is an attention-weighted sum over dim 1,
        L2-normalised along the last dim.
        """
        record = [torch.mean(input_, dim=1), torch.max(input_, dim=1).values]
        # attention scores, softmax-normalised over the time axis: B x T x A
        attentions = F.softmax(torch.matmul(input_, self.weight), dim=1)
        for idx in range(attentions.shape[-1]):
            head_weights = attentions[:, :, idx].unsqueeze(2)
            pooled = torch.sum(head_weights * input_, dim=1)
            norm = pooled.norm(p=2, dim=-1, keepdim=True)
            record.append(pooled.div(norm))
        return torch.cat(record, 1)
class RelationModuleMultiScale(torch.nn.Module):
    """
    RelationModuleMultiScale Module

    Temporal relation module over multiple scales. It sums the fused outputs of
    the ``num_frames``-frame relation and of ``(num_frames - 1)``-frame, ...,
    2-frame relations. For every scale below the largest, at most
    ``subsample_num`` frame combinations are drawn at random with
    ``np.random.choice`` on each forward pass.
    """

    def __init__(self, img_feature_dim, num_frames, num_class):
        super().__init__()
        self.subsample_num = 3  # how many relations selected to sum up
        self.img_feature_dim = img_feature_dim
        # scales run from the largest (all frames) down to pairs of frames
        self.scales = list(range(num_frames, 1, -1))
        self.relations_scales = [
            self.return_relationset(num_frames, scale) for scale in self.scales
        ]
        # how many samples of relation to select in each forward pass
        self.subsample_scales = [
            min(self.subsample_num, len(relations)) for relations in self.relations_scales
        ]
        self.num_class = num_class
        self.num_frames = num_frames
        bottleneck = 256
        # one fusion MLP per scale, fed the flattened frames of a relation
        self.fc_fusion_scales = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(scale * self.img_feature_dim, bottleneck),
                nn.ReLU(),
                nn.Linear(bottleneck, self.num_class),
            )
            for scale in self.scales
        ])

    def _fuse_relation(self, input_, scale_id, relation):
        """Gather the frames in ``relation``, flatten them and apply the scale's MLP."""
        frames = input_[:, relation, :]
        flat = frames.view(frames.size(0), self.scales[scale_id] * self.img_feature_dim)
        return self.fc_fusion_scales[scale_id](flat)

    def forward(self, input_):
        # the largest scale has exactly one relation: all frames
        total = self._fuse_relation(input_, 0, self.relations_scales[0][0])
        for scale_id in range(1, len(self.scales)):
            chosen = np.random.choice(
                len(self.relations_scales[scale_id]),
                self.subsample_scales[scale_id],
                replace=False,
            )
            for pick in chosen:
                total += self._fuse_relation(
                    input_, scale_id, self.relations_scales[scale_id][pick])
        return total

    def return_relationset(self, num_frames, num_frames_relation):
        """All ``num_frames_relation``-sized combinations of ``range(num_frames)``."""
        return list(itertools.combinations(range(num_frames), num_frames_relation))
class RelationModuleMultiScale_Cat(torch.nn.Module):  # pylint: disable=invalid-name
    """
    RelationModuleMultiScale_Cat Module

    Multi-scale temporal relation module that concatenates the per-scale
    outputs instead of summing them. For each scale (``num_frames`` down to 2
    frames), the fused outputs of the sampled relations are summed,
    L2-normalised along the last dim and concatenated along dim 1. Below the
    largest scale, at most ``subsample_num`` relations are drawn at random with
    ``np.random.choice`` on each forward pass.
    """

    def __init__(self, img_feature_dim, num_frames, num_class):
        super().__init__()
        self.subsample_num = 3  # how many relations selected to sum up
        self.img_feature_dim = img_feature_dim
        # generate the multiple frame relations, largest scale first
        self.scales = list(range(num_frames, 1, -1))
        self.relations_scales = [
            self.return_relationset(num_frames, scale) for scale in self.scales
        ]
        # how many samples of relation to select in each forward pass
        self.subsample_scales = [
            min(self.subsample_num, len(relations)) for relations in self.relations_scales
        ]
        self.num_class = num_class
        self.num_frames = num_frames
        bottleneck = 256
        self.fc_fusion_scales = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(scale * self.img_feature_dim, bottleneck),
                nn.ReLU(),
                nn.Linear(bottleneck, self.num_class),
            )
            for scale in self.scales
        ])

    def _fuse_relation(self, input_, scale_id, relation):
        """Gather the frames in ``relation``, flatten them and apply the scale's MLP."""
        frames = input_[:, relation, :]
        flat = frames.view(frames.size(0), self.scales[scale_id] * self.img_feature_dim)
        return self.fc_fusion_scales[scale_id](flat)

    @staticmethod
    def _l2_normalize(features):
        """Divide ``features`` by its L2 norm along the last dim."""
        return features.div(features.norm(p=2, dim=-1, keepdim=True))

    def forward(self, input_):
        # the largest scale has exactly one relation: all frames
        outputs = [self._l2_normalize(
            self._fuse_relation(input_, 0, self.relations_scales[0][0]))]
        for scale_id in range(1, len(self.scales)):
            chosen = np.random.choice(len(self.relations_scales[scale_id]),
                                      self.subsample_scales[scale_id], replace=False)
            scale_total = 0
            for pick in chosen:
                scale_total += self._fuse_relation(
                    input_, scale_id, self.relations_scales[scale_id][pick])
            outputs.append(self._l2_normalize(scale_total))
        return torch.cat(outputs, 1)

    def return_relationset(self, num_frames, num_frames_relation):
        """All ``num_frames_relation``-sized combinations of ``range(num_frames)``."""
        return list(itertools.combinations(range(num_frames), num_frames_relation))
class CEModule(nn.Module):
    """
    CE Module

    Builds the joint video/text embedding heads of the collaborative experts
    model. The video experts are fused by one of:

    * the collaborative gating mechanism (``use_ce``);
    * mimic-CE gated units (``mimic_ce_dims``);
    * one large gated unit (``concat_mix_experts``);
    * plain concatenation (``concat_experts``);
    * one gated unit per expert (default).

    For retrieval tasks the text is embedded once per expert and
    mixture-of-experts weights over modalities are predicted from the text. For
    other tasks a linear classifier is built over the concatenated expert
    embeddings.
    """
    def __init__(self, expert_dims, text_dim, use_ce, verbose, l2renorm, num_classes,
                 trn_config, trn_cat, use_mish, include_self, num_h_layers, num_g_layers,
                 disable_nan_checks, random_feats, test_caption_mode, mimic_ce_dims,
                 concat_experts, concat_mix_experts, freeze_weights, task,
                 keep_missing_modalities, vlad_feat_sizes, same_dim, use_bn_reason):
        super().__init__()
        modalities = list(expert_dims.keys())
        self.expert_dims = expert_dims
        self.modalities = modalities
        self.disable_nan_checks = disable_nan_checks
        self.mimic_ce_dims = mimic_ce_dims
        self.concat_experts = concat_experts
        self.same_dim = same_dim
        self.use_mish = use_mish
        self.use_bn_reason = use_bn_reason
        self.num_h_layers = num_h_layers
        self.num_g_layers = num_g_layers
        self.include_self = include_self
        self.num_classes = num_classes
        self.task = task
        self.vlad_feat_sizes = vlad_feat_sizes
        self.concat_mix_experts = concat_mix_experts
        self.test_caption_mode = test_caption_mode
        self.reduce_dim = 64
        self.moe_cg = ContextGating
        self.freeze_weights = freeze_weights
        self.random_feats = random_feats
        self.use_ce = use_ce
        self.verbose = verbose
        self.keep_missing_modalities = keep_missing_modalities
        self.l2renorm = l2renorm
        self.trn_config = trn_config
        self.trn_cat = trn_cat

        if self.use_mish:
            self.non_lin = Mish()
        else:
            self.non_lin = nn.ReLU()

        if "retrieval" in self.task:
            num_mods = len(expert_dims)
            self.moe_fc = nn.Linear(text_dim, len(expert_dims))
            # uniform weights, used when ``freeze_weights`` is set
            self.moe_weights = torch.ones(1, num_mods) / num_mods

        use_bns = [True for _ in self.modalities]

        # Optional temporal modules, one per ``trn_config`` key (in key order).
        # ``repeat_temporal[mod]`` is the factor by which the temporal module
        # multiplies that expert's feature dimension.
        self.trn_list = nn.ModuleList()
        self.repeat_temporal = {}
        for mod in modalities:
            self.repeat_temporal[mod] = 1

        if self.trn_cat == 2:
            for mod in self.trn_config.keys():
                img_feature_dim = expert_dims[mod][0]  # 365
                # trn_config[mod] is deliberately not used here: a single attention
                # head mimics simple avg and max based on segments
                num_frames = 1
                self.trn_list += [TemporalAttention(img_feature_dim, num_frames)]
                # avg + max + num_frames attention heads
                self.repeat_temporal[mod] = num_frames + 2
        elif self.trn_cat == 1:
            for mod in self.trn_config.keys():
                img_feature_dim = expert_dims[mod][0]  # 365
                num_frames = self.trn_config[mod]  # hard code
                num_class = expert_dims[mod][0]
                self.trn_list += [
                    RelationModuleMultiScale_Cat(img_feature_dim, num_frames, num_class)
                ]
                # one concatenated block per scale (num_frames, ..., 2)
                self.repeat_temporal[mod] = len(list(range(num_frames, 1, -1)))
        elif self.trn_cat == 0:
            for mod in self.trn_config.keys():
                img_feature_dim = expert_dims[mod][0]  # 365
                num_frames = self.trn_config[mod]  # hard code
                num_class = expert_dims[mod][0]
                self.trn_list += [
                    RelationModuleMultiScale(img_feature_dim, num_frames, num_class)
                ]
        else:
            raise NotImplementedError()

        in_dims = [expert_dims[mod][0] * self.repeat_temporal[mod] for mod in modalities]
        agg_dims = [expert_dims[mod][1] * self.repeat_temporal[mod] for mod in modalities]

        if self.use_ce or self.mimic_ce_dims:
            dim_reducers = [ReduceDim(in_dim, same_dim) for in_dim in in_dims]
            self.video_dim_reduce = nn.ModuleList(dim_reducers)

        if self.use_ce:
            # The g_reason module has a first layer that is specific to the design choice
            # (e.g. triplet vs pairwise), then a shared component which is common to all
            # designs.
            if self.use_ce in {"pairwise", "pairwise-star", "triplet"}:
                num_inputs = 3 if self.use_ce == "triplet" else 2
                self.g_reason_1 = nn.Linear(same_dim * num_inputs, same_dim)
            elif self.use_ce == "pairwise-star-specific":
                num_inputs = 2
                g_reason_unshared_weights = [G_reason(same_dim, num_inputs, self.non_lin)
                                             for _ in modalities]
                self.g_reason_unshared_weights = nn.ModuleList(g_reason_unshared_weights)
            elif self.use_ce in {"pairwise-star-tensor"}:
                reduce_dim = self.reduce_dim
                self.dim_reduce = nn.Linear(same_dim, reduce_dim)
                self.g_reason_1 = nn.Linear(self.reduce_dim * reduce_dim, same_dim)
            else:
                raise ValueError(f"unrecognised CE config: {self.use_ce}")

            g_reason_shared = []
            for _ in range(self.num_g_layers - 1):
                if self.use_bn_reason:
                    g_reason_shared.append(nn.BatchNorm1d(same_dim))
                g_reason_shared.append(self.non_lin)
                g_reason_shared.append(nn.Linear(same_dim, same_dim))
            self.g_reason_shared = nn.Sequential(*g_reason_shared)

            h_reason = []
            for _ in range(self.num_h_layers):
                if self.use_bn_reason:
                    h_reason.append(nn.BatchNorm1d(same_dim))
                h_reason.append(self.non_lin)
                h_reason.append(nn.Linear(same_dim, same_dim))
            self.h_reason = nn.Sequential(*h_reason)

            gated_vid_embds = [GatedEmbeddingUnitReasoning(same_dim) for _ in in_dims]
            text_out_dims = [same_dim for _ in agg_dims]
        elif self.mimic_ce_dims:  # ablation study
            gated_vid_embds = [MimicCEGatedEmbeddingUnit(same_dim, same_dim, use_bn=True)
                               for _ in modalities]
            text_out_dims = [same_dim for _ in agg_dims]
        elif self.concat_mix_experts:  # ablation study
            # use a single large GEU to mix the experts - the output will be the sum
            # of the aggregation sizes
            in_dim, out_dim = sum(in_dims), sum(agg_dims)
            gated_vid_embds = [GatedEmbeddingUnit(in_dim, out_dim, use_bn=True)]
        elif self.concat_experts:  # ablation study
            # We do not use learnable parameters for the video combination, (we simply
            # use a high dimensional inner product).
            gated_vid_embds = []
        else:
            gated_vid_embds = [GatedEmbeddingUnit(in_dim, dim, use_bn) for
                               in_dim, dim, use_bn in zip(in_dims, agg_dims, use_bns)]
            text_out_dims = agg_dims
        self.video_GU = nn.ModuleList(gated_vid_embds)  # pylint: disable=invalid-name

        if "retrieval" in self.task:
            if self.concat_experts:
                gated_text_embds = [nn.Sequential()]
            elif self.concat_mix_experts:
                # As with the video inputs, we similiarly use a single large GEU for the
                # text embedding
                gated_text_embds = [GatedEmbeddingUnit(text_dim, sum(agg_dims),
                                                       use_bn=True)]
            else:
                gated_text_embds = [GatedEmbeddingUnit(text_dim, dim, use_bn=True) for
                                    dim in text_out_dims]
            self.text_GU = nn.ModuleList(gated_text_embds)  # pylint: disable=invalid-name
        else:
            total_dim = 0
            for mod in self.expert_dims.keys():
                total_dim += self.expert_dims[mod][1] * self.repeat_temporal[mod]
            self.classifier = nn.Linear(total_dim, self.num_classes)

    def compute_moe_weights(self, text, ind):
        """Predict mixture-of-experts weights over the modalities for each caption.

        Args:
            text (torch.Tensor): text features unpacked as ``(B, K, D)`` (K captions
                per video).
            ind: unused.

        Returns:
            torch.Tensor: weights of shape ``(B, K, M)``. Each caption gets a
            softmax over the M modalities, or uniform ``1 / M`` weights when
            ``freeze_weights`` is set.
        """
        _ = ind
        # compute weights for all captions (including when assigned K captions to
        # the same video)
        bb, kk, dd = text.shape
        mm = len(self.modalities)
        msg = f"expected between 1 and 10 modalities, found {mm} ({self.modalities})"
        assert 1 <= mm <= 10, msg

        # Treat each caption independently in the softmax (which runs over modalities)
        text = text.view(bb * kk, dd)
        if self.freeze_weights:
            moe_weights = self.moe_weights.repeat(bb, kk, 1)
            # follow the text onto its device (works for CPU and any GPU index)
            moe_weights = moe_weights.to(text.device)
        else:
            moe_weights = self.moe_fc(text)  # BK x D -> BK x M
            moe_weights = F.softmax(moe_weights, dim=1)
            moe_weights = moe_weights.view(bb, kk, mm)

        if self.verbose:
            print("--------------------------------")
            for idx, key in enumerate(self.modalities):
                msg = "{}: mean: {:.3f}, std: {:.3f}, min: {:.3f}, max: {:.3f}"
                msg = msg.format(
                    key,
                    moe_weights[:, :, idx].mean().item(),
                    moe_weights[:, :, idx].std().item(),
                    moe_weights[:, :, idx].min().item(),
                    moe_weights[:, :, idx].max().item(),
                )
                print(msg)
        return moe_weights

    def forward(self, text, experts, ind, raw_captions):
        """Compute joint embeddings and, if requested, a confusion matrix between
        video and text representations in the minibatch.

        Notation: B = batch size, M = number of modalities

        Args:
            text: for retrieval tasks, pooled text features unpacked as
                ``(B, captions_per_video, feat_dim)``; unused otherwise.
            experts (dict[str, torch.Tensor]): per-modality video features; the
                dict is updated in place with the embedded features.
            ind (dict[str, torch.Tensor]): per-modality availability flags.
            raw_captions: forwarded to ``sharded_cross_view_inner_product``.

        Returns:
            dict: always contains ``"modalities"``, plus one of:

            * ``"class_preds"`` for classification;
            * ``"embeddings"`` for ``compute_video_embeddings``;
            * ``"cross_view_conf_matrix"``, ``"text_embds"`` and
              ``"vid_embds"`` otherwise.
        """
        if "retrieval" in self.task:
            # Pass text embeddings through gated units
            text_embd = {}

            # Unroll repeated captions into present minibatch
            bb, captions_per_video, feat_dim = text.size()
            text = text.view(bb * captions_per_video, feat_dim)
            for modality, layer in zip(self.modalities, self.text_GU):
                # NOTE: Due to the batch norm, the gated units are sensitive to passing
                # in a lot of zeroes, so we do the masking step after the forwards pass
                text_ = layer(text)

                # We always assume that text is available for retrieval
                text_ = text_.view(bb, captions_per_video, -1)

                if "text" in self.random_feats:
                    text_ = torch.rand_like(text_)
                text_embd[modality] = text_
            text = text.view(bb, captions_per_video, -1)

            # vladded nans are handled earlier (during pooling)
            # We also avoid zeroing random features, since this will leak information

            # MOE weights computation + normalization - note that we use the first caption
            # sample to predict the weights
            moe_weights = self.compute_moe_weights(text, ind=ind)

        if self.l2renorm:
            for modality in self.modalities:
                norm = experts[modality].norm(p=2, dim=-1, keepdim=True)
                experts[modality] = experts[modality].div(norm)

        # trn_list was built by iterating trn_config's keys, so each temporal layer
        # must be paired with the modality it was built for (zipping with
        # self.modalities applied the wrong layer whenever the orders differed).
        for modality, layer in zip(self.trn_config.keys(), self.trn_list):
            experts[modality] = layer(experts[modality])

        if hasattr(self, "video_dim_reduce"):
            # Embed all features to a common dimension
            for modality, layer in zip(self.modalities, self.video_dim_reduce):
                experts[modality] = layer(experts[modality])

        if self.use_ce:
            dev = experts[self.modalities[0]].device
            if self.include_self:
                all_combinations = list(itertools.product(experts, repeat=2))
            else:
                all_combinations = list(itertools.permutations(experts, 2))
            assert len(self.modalities) > 1, "use_ce requires multiple modalities"

            if self.use_ce in {"pairwise-star", "pairwise-star-specific",
                               "pairwise-star-tensor"}:
                # availability-weighted mean over all modalities
                sum_all = 0
                sum_ind = 0
                for mod0 in experts.keys():
                    sum_all += (experts[mod0] * ind[mod0].float().to(dev).unsqueeze(1))
                    sum_ind += ind[mod0].float().to(dev).unsqueeze(1)
                avg_modality = sum_all / sum_ind

            for ii, gated_unit in enumerate(self.video_GU):
                mask_num = 0
                curr_mask = 0
                temp_dict = {}
                avai_dict = {}
                curr_modality = self.modalities[ii]

                if self.use_ce == "pairwise-star":
                    fused = torch.cat((experts[curr_modality], avg_modality), 1)  # -> B x 2D
                    temp = self.g_reason_1(fused)  # B x 2D -> B x D
                    temp = self.g_reason_shared(temp)  # B x D -> B x D
                    curr_mask = temp * ind[curr_modality].float().to(dev).unsqueeze(1)
                elif self.use_ce == "pairwise-star-specific":
                    fused = torch.cat((experts[curr_modality], avg_modality), 1)  # -> B x 2D
                    temp = self.g_reason_unshared_weights[ii](fused)
                    temp = self.g_reason_shared(temp)  # B x D -> B x D
                    curr_mask = temp * ind[curr_modality].float().to(dev).unsqueeze(1)
                elif self.use_ce == "pairwise-star-tensor":
                    mod0_reduce = self.dim_reduce(experts[curr_modality])
                    mod0_reduce = mod0_reduce.unsqueeze(2)  # B x reduced_dim x1
                    mod1_reduce = self.dim_reduce(avg_modality)
                    mod1_reduce = mod1_reduce.unsqueeze(1)  # B x1 x reduced_dim
                    flat_dim = self.reduce_dim * self.reduce_dim
                    fused = torch.matmul(mod0_reduce, mod1_reduce).view(-1, flat_dim)
                    temp = self.g_reason_1(fused)  # B x 2D -> B x D
                    temp = self.g_reason_shared(temp)  # B x D -> B x D
                    curr_mask = temp * ind[curr_modality].float().to(dev).unsqueeze(1)
                elif self.use_ce in {"pairwise", "triplet"}:
                    for modality_pair in all_combinations:
                        mod0, mod1 = modality_pair
                        if self.use_ce == "pairwise":
                            if mod0 == curr_modality:
                                new_key = f"{mod0}_{mod1}"
                                fused = torch.cat((experts[mod0], experts[mod1]), 1)
                                temp = self.g_reason_1(fused)  # B x 2D -> B x D
                                temp = self.g_reason_shared(temp)
                                temp_dict[new_key] = temp
                                avail = (ind[mod0].float() * ind[mod1].float())
                                avai_dict[new_key] = avail.to(dev)
                        elif self.use_ce == "triplet":
                            if (curr_modality not in {mod0, mod1}) or self.include_self:
                                new_key = f"{curr_modality}_{mod0}_{mod1}"
                                fused = torch.cat((experts[curr_modality], experts[mod0],
                                                   experts[mod1]), 1)  # -> B x 2D
                                temp = self.g_reason_1(fused)  # B x 2D -> B x D
                                temp = self.g_reason_shared(temp)
                                temp_dict[new_key] = temp
                                avail = (ind[curr_modality].float() * ind[mod0].float() *
                                         ind[mod1].float()).to(dev)
                                avai_dict[new_key] = avail

                    # Combine the paired features into a mask through elementwise sum
                    for pair_key, value in temp_dict.items():
                        curr_mask += value * avai_dict[pair_key].unsqueeze(1)
                        mask_num += avai_dict[pair_key]
                    curr_mask = torch.div(curr_mask, (mask_num + 0.00000000001).unsqueeze(1))
                else:
                    raise ValueError(f"Unknown CE mechanism: {self.use_ce}")
                curr_mask = self.h_reason(curr_mask)
                experts[curr_modality] = gated_unit(experts[curr_modality], curr_mask)
        elif self.concat_mix_experts:
            concatenated = torch.cat(tuple(experts.values()), dim=1)
            vid_embd_ = self.video_GU[0](concatenated)
            text_embd_ = text_embd[self.modalities[0]]
            text_embd_ = text_embd_.view(-1, text_embd_.shape[-1])
        elif self.concat_experts:
            vid_embd_ = torch.cat(tuple(experts.values()), dim=1)
            text_embd_ = text_embd[self.modalities[0]]
            text_embd_ = text_embd_.view(-1, text_embd_.shape[-1])
        else:
            for modality, layer in zip(self.modalities, self.video_GU):
                experts[modality] = layer(experts[modality])

        if self.training:
            merge_caption_similiarities = "avg"
        else:
            merge_caption_similiarities = self.test_caption_mode

        if self.task == "classification":
            concatenated = torch.cat(tuple(experts.values()), dim=1)
            preds = self.classifier(concatenated)
            return {"modalities": self.modalities, "class_preds": preds}
        elif self.concat_experts or self.concat_mix_experts:
            # zero pad to accommodate mismatch in sizes (after first setting the number
            # of VLAD clusters for the text to get the two vectors as close as possible
            # in size)
            if text_embd_.shape[1] > vid_embd_.shape[1]:
                sz = (vid_embd_.shape[0], text_embd_.shape[1])
                dtype, device = vid_embd_.dtype, vid_embd_.device
                vid_embd_padded = torch.zeros(size=sz, dtype=dtype, device=device)
                # copy the real video embedding into the leading columns (this copy
                # was previously commented out, leaving an all-zero embedding)
                vid_embd_padded[:, :vid_embd_.shape[1]] = vid_embd_
                vid_embd_ = vid_embd_padded
            else:
                sz = (text_embd_.shape[0], vid_embd_.shape[1])
                dtype, device = text_embd_.dtype, text_embd_.device
                text_embd_padded = torch.zeros(size=sz, dtype=dtype, device=device)
                text_embd_padded[:, :text_embd_.shape[1]] = text_embd_
                text_embd_ = text_embd_padded
            cross_view_conf_matrix = torch.matmul(text_embd_, vid_embd_.t())
        elif self.task == "compute_video_embeddings":
            return {"modalities": self.modalities, "embeddings": experts}
        else:
            cross_view_conf_matrix = sharded_cross_view_inner_product(
                ind=ind,
                vid_embds=experts,
                text_embds=text_embd,
                keep_missing_modalities=self.keep_missing_modalities,
                l2renorm=self.l2renorm,
                text_weights=moe_weights,
                subspaces=self.modalities,
                raw_captions=raw_captions,
                merge_caption_similiarities=merge_caption_similiarities,
            )
        return {
            "modalities": self.modalities,
            "cross_view_conf_matrix": cross_view_conf_matrix,
            "text_embds": text_embd,
            "vid_embds": experts,
        }
class GatedEmbeddingUnit(nn.Module):
    """
    GatedEmbeddingUnit

    Linear projection to ``output_dimension``, then context gating (optionally
    with batch norm), then L2 normalisation along dim 1.
    """
    def __init__(self, input_dimension, output_dimension, use_bn):
        super().__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)
        self.cg = ContextGating(output_dimension, add_batch_norm=use_bn)

    def forward(self, x):
        projected = self.fc(x)
        gated = self.cg(projected)
        return F.normalize(gated)
class MimicCEGatedEmbeddingUnit(nn.Module):
    """
    Context gating followed by L2 normalisation, with no input projection.

    ``output_dimension`` is accepted (mirroring ``GatedEmbeddingUnit``'s
    signature) but ignored: the output keeps ``input_dimension`` features.
    """
    def __init__(self, input_dimension, output_dimension, use_bn):
        super().__init__()
        del output_dimension  # unused
        self.cg = ContextGating(input_dimension, add_batch_norm=use_bn)

    def forward(self, x):
        return F.normalize(self.cg(x))
class ReduceDim(nn.Module):
    """
    ReduceDim Module

    Linearly projects features to ``output_dimension`` and L2-normalises the
    result along dim 1.
    """
    def __init__(self, input_dimension, output_dimension):
        super().__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)

    def forward(self, x):
        return F.normalize(self.fc(x))
class ContextGating(nn.Module):
    """
    ContextGating Module

    Gates the input with a sigmoid of a learned linear projection of itself:
    ``out = x * sigmoid(BN(fc(x)))``. The batch norm is skipped when
    ``add_batch_norm`` is False.
    """
    def __init__(self, dimension, add_batch_norm=True):
        super().__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        # NOTE: created even when add_batch_norm is False, so its parameters are
        # always part of the state dict.
        self.batch_norm = nn.BatchNorm1d(dimension)

    def forward(self, x):
        gate_logits = self.fc(x)
        if self.add_batch_norm:
            gate_logits = self.batch_norm(gate_logits)
        # glu on [x, gate_logits] along dim 1 computes x * sigmoid(gate_logits)
        stacked = torch.cat((x, gate_logits), 1)
        return F.glu(stacked, 1)
class GatedEmbeddingUnitReasoning(nn.Module):
    """
    Mask-conditioned context gating (``ContextGatingReasoning``) followed by
    L2 normalisation along dim 1.
    """
    def __init__(self, output_dimension):
        super().__init__()
        self.cg = ContextGatingReasoning(output_dimension)

    def forward(self, x, mask):
        gated = self.cg(x, mask)
        return F.normalize(gated)
class SpatialMLP(nn.Module):
    """
    Two stacked ``ContextGating`` layers (each with batch norm) over
    ``dimension``-sized features.
    """
    def __init__(self, dimension):
        super().__init__()
        self.cg1 = ContextGating(dimension)
        self.cg2 = ContextGating(dimension)

    def forward(self, x):
        return self.cg2(self.cg1(x))
class ContextGatingReasoning(nn.Module):
    """
    ContextGatingReasoning

    Context gating whose gate also depends on an external signal ``x1``:
    ``out = x * sigmoid(BN(x1) + BN2(fc(x)))``. Both batch norms are skipped
    when ``add_batch_norm`` is False.
    """
    def __init__(self, dimension, add_batch_norm=True):
        super().__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        # batch_norm normalises the external signal, batch_norm2 the projection
        self.batch_norm = nn.BatchNorm1d(dimension)
        self.batch_norm2 = nn.BatchNorm1d(dimension)

    def forward(self, x, x1):
        projected = self.fc(x)
        external = x1
        if self.add_batch_norm:
            external = self.batch_norm(external)
            projected = self.batch_norm2(projected)
        gate_logits = external + projected
        # glu on [x, gate_logits] along dim 1 computes x * sigmoid(gate_logits)
        return F.glu(torch.cat((x, gate_logits), 1), 1)
class G_reason(nn.Module):  # pylint: disable=invalid-name
    """
    G_reason Module

    Two-layer MLP ``Linear(same_dim * num_inputs -> same_dim) -> non_lin ->
    Linear(same_dim -> same_dim)``.
    """
    def __init__(self, same_dim, num_inputs, non_lin):
        super().__init__()
        self.g_reason_1_specific = nn.Linear(same_dim * num_inputs, same_dim)
        self.g_reason_2_specific = nn.Linear(same_dim, same_dim)
        self.non_lin = non_lin

    def forward(self, x):
        hidden = self.non_lin(self.g_reason_1_specific(x))  # B x 2D -> B x D
        return self.g_reason_2_specific(hidden)
def sharded_cross_view_inner_product(vid_embds, text_embds, text_weights,
subspaces, l2renorm, ind,
keep_missing_modalities,
merge_caption_similiarities="avg", tol=1E-5,
raw_captions=None):
"""Compute a similarity matrix from sharded vectors.
Args:
embds1 (dict[str:torch.Tensor]): the set of sub-embeddings that, when
concatenated, form the whole. The ith shard has shape `B x K x F_i`
(i.e. they can differ in the last dimension).
embds2 (dict[str:torch.Tensor]): same format.
weights2 (torch.Tensor): weights for the shards in `embds2`.
l2norm (bool::True): whether to l2 renormalize the full embeddings.
Returns:
(torch.tensor): similarity matrix of size `BK x BK`.
NOTE: If multiple captions are provided, we can aggregate their similarities to
provide a single video-text similarity score.
"""
_ = raw_captions
bb = vid_embds[subspaces[0]].size(0)
tt, num_caps, _ = text_embds[subspaces[0]].size()
device = vid_embds[subspaces[0]].device
# unroll separate captions onto first dimension and treat them separately
sims = torch.zeros(tt * num_caps, bb, device=device)
text_weights = text_weights.view(tt * num_caps, -1)