-
Notifications
You must be signed in to change notification settings - Fork 0
/
MoCo+McGIP.py
172 lines (141 loc) · 6.03 KB
/
MoCo+McGIP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import numpy as np
from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
concat_all_gather)
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel
@ALGORITHMS.register_module()
class MoCo_gaze(BaseModel):
    """MoCo with McGIP gaze-based extra positives.

    Implementation of MoCo+McGIP. On top of the standard MoCo
    instance-discrimination setup, a precomputed gaze-similarity matrix
    (loaded from ``relation``) marks queue entries whose gaze similarity
    to a query image exceeds ``threshold`` as additional positive pairs.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors. Defaults to None.
        head (dict): Config dict for module of loss functions.
            Defaults to None.
        queue_len (int): Number of negative keys maintained in the queue.
            Defaults to 1280.
        feat_dim (int): Dimension of compact feature vectors. Defaults to 256.
        momentum (float): Momentum coefficient for the momentum-updated
            encoder. Defaults to 0.999.
        init_cfg (dict, optional): Initialization config passed to BaseModel.
            Defaults to None.
        threshold (float): Threshold on gaze similarity for the construction
            of positive pairs. Defaults to 0.7.
        relation (str): Path to a ``.npy`` file containing the pairwise gaze
            similarity matrix, indexed by dataset sample index.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 queue_len=1280,
                 feat_dim=256,
                 momentum=0.999,
                 init_cfg=None,
                 threshold=0.7, relation='./relation_multimatch.npy',
                 **kwargs):
        super(MoCo_gaze, self).__init__(init_cfg)
        assert neck is not None
        # Query and key encoders share the same backbone+neck architecture.
        self.encoder_q = nn.Sequential(
            build_backbone(backbone), build_neck(neck))
        self.encoder_k = nn.Sequential(
            build_backbone(backbone), build_neck(neck))
        # Initialize the key encoder from the query encoder and freeze it;
        # it is only ever updated by EMA in _momentum_update_key_encoder.
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        self.backbone = self.encoder_q[0]
        self.neck = self.encoder_q[1]
        assert head is not None
        self.head = build_head(head)
        self.queue_len = queue_len
        self.momentum = momentum
        # create a queue saving image representations (feat_dim x queue_len),
        # L2-normalized along the feature dimension
        self.register_buffer('queue', torch.randn(feat_dim, queue_len))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
        # create a queue saving image idx, in order to get their gaze
        # similarity from self.relation with ease
        # NOTE(review): initial idx values are random in [0, 700), which
        # presumably assumes the dataset has at most 700 samples -- confirm
        # against the dataset size used with `relation`.
        self.register_buffer('queue_idx', torch.randint(0, 700, (queue_len, 1), dtype=torch.long))
        self.threshold = threshold
        # Pairwise gaze-similarity matrix, loaded once at construction.
        self.relation = np.load(relation)

    def _create_buffer(self, N, idx_list):
        """Build extra-positive labels for the batch against the queue.

        Args:
            N (int): Batch size (number of query images).
            idx_list (Tensor): Dataset indices of the current batch images.

        Returns:
            Tensor: ``(N, queue_len)`` long tensor on CUDA; entry ``(i, j)``
            is 1 when the gaze similarity between query ``i`` and queue
            entry ``j`` exceeds ``self.threshold``, else 0.
        """
        labels = torch.zeros([N, self.queue_len], dtype=torch.long)
        for i in range(N):
            idx = int(idx_list[i].item())
            for j in range(self.queue_len):
                jdx = int(self.queue_idx[j].item())
                if (i == j):
                    # NOTE(review): this compares the batch position `i` with
                    # the queue position `j`, which index unrelated things;
                    # it looks like `idx == jdx` (skip the same image) may
                    # have been intended -- confirm with the authors.
                    pass
                else:
                    sim = self.relation[idx][jdx]
                    if (sim > self.threshold):
                        labels[i][j] = 1
        # NOTE(review): hard-codes CUDA; a CPU-only run would fail here.
        labels = labels.cuda()
        return labels

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum (EMA) update of the key encoder from the query encoder."""
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data = param_k.data * self.momentum + \
                param_q.data * (1. - self.momentum)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, idx_list):
        """Update queue with the batch keys and their dataset indices.

        Args:
            keys (Tensor): Normalized key features of shape (N, C).
            idx_list (Tensor): Dataset indices of the current batch images.
        """
        # gather keys before updating queue
        keys = concat_all_gather(keys)
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.queue_len % batch_size == 0  # for simplicity
        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        # NOTE(review): `keys` is gathered across ranks above, but `idx_list`
        # is not; under multi-GPU DDP the two lengths will disagree here --
        # idx_list likely needs concat_all_gather as well. Confirm before
        # running distributed.
        self.queue_idx[ptr:ptr + batch_size, 0] = idx_list
        ptr = (ptr + batch_size) % self.queue_len  # move pointer
        self.queue_ptr[0] = ptr

    def extract_feat(self, img):
        """Function to extract features from backbone.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            tuple[Tensor]: backbone outputs.
        """
        x = self.backbone(img)
        return x

    def forward_train(self, img, idx, **kwargs):
        """Forward computation during training.

        Args:
            img (list[Tensor]): A list of two augmented views with shape
                (N, C, H, W). Typically these should be mean centered
                and std scaled.
            idx (Tensor): Dataset indices of the batch images, used to look
                up gaze similarity in ``self.relation``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert isinstance(img, list)
        idx_list = idx
        im_q = img[0]
        im_k = img[1]
        # compute query features
        q = self.encoder_q(im_q)[0]  # queries: NxC
        q = nn.functional.normalize(q, dim=1)
        # compute key features
        with torch.no_grad():  # no gradient to keys
            # update the key encoder
            self._momentum_update_key_encoder()
            # shuffle for making use of BN
            im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
            k = self.encoder_k(im_k)[0]  # keys: NxC
            k = nn.functional.normalize(k, dim=1)
            # undo shuffle
            k = batch_unshuffle_ddp(k, idx_unshuffle)
        # extra positives from gaze similarity between batch and queue entries
        labels = self._create_buffer(q.shape[0], idx_list)
        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        losses = self.head(l_pos, l_neg, labels)
        # update the queue
        self._dequeue_and_enqueue(k, idx_list)
        return losses